pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pertpy/__init__.py +3 -2
- pertpy/data/__init__.py +5 -1
- pertpy/data/_dataloader.py +2 -4
- pertpy/data/_datasets.py +203 -92
- pertpy/metadata/__init__.py +4 -0
- pertpy/metadata/_cell_line.py +826 -0
- pertpy/metadata/_compound.py +129 -0
- pertpy/metadata/_drug.py +242 -0
- pertpy/metadata/_look_up.py +582 -0
- pertpy/metadata/_metadata.py +73 -0
- pertpy/metadata/_moa.py +129 -0
- pertpy/plot/__init__.py +1 -9
- pertpy/plot/_augur.py +53 -116
- pertpy/plot/_coda.py +277 -677
- pertpy/plot/_guide_rna.py +17 -35
- pertpy/plot/_milopy.py +59 -134
- pertpy/plot/_mixscape.py +152 -391
- pertpy/preprocessing/_guide_rna.py +88 -4
- pertpy/tools/__init__.py +8 -13
- pertpy/tools/_augur.py +315 -17
- pertpy/tools/_cinemaot.py +143 -4
- pertpy/tools/_coda/_base_coda.py +1210 -65
- pertpy/tools/_coda/_sccoda.py +50 -21
- pertpy/tools/_coda/_tasccoda.py +27 -19
- pertpy/tools/_dialogue.py +164 -56
- pertpy/tools/_differential_gene_expression.py +240 -14
- pertpy/tools/_distances/_distance_tests.py +8 -8
- pertpy/tools/_distances/_distances.py +184 -34
- pertpy/tools/_enrichment.py +465 -0
- pertpy/tools/_milo.py +345 -11
- pertpy/tools/_mixscape.py +668 -50
- pertpy/tools/_perturbation_space/_clustering.py +5 -1
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
- pertpy/tools/_perturbation_space/_simple.py +51 -10
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_scgen.py +701 -0
- pertpy/tools/_scgen/_utils.py +1 -3
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
- pertpy-0.7.0.dist-info/RECORD +53 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_coda/_base_coda.py
CHANGED
@@ -1,17 +1,23 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
from abc import ABC, abstractmethod
|
4
|
-
from
|
4
|
+
from pathlib import Path
|
5
|
+
from typing import TYPE_CHECKING, Literal, Optional, Union
|
5
6
|
|
6
7
|
import arviz as az
|
7
|
-
import ete3 as ete
|
8
8
|
import jax.numpy as jnp
|
9
|
+
import matplotlib.pyplot as plt
|
9
10
|
import numpy as np
|
10
11
|
import pandas as pd
|
11
12
|
import patsy as pt
|
13
|
+
import scanpy as sc
|
14
|
+
import seaborn as sns
|
15
|
+
from adjustText import adjust_text
|
12
16
|
from anndata import AnnData
|
13
|
-
from jax import random
|
14
|
-
from
|
17
|
+
from jax import config, random
|
18
|
+
from matplotlib import cm, rcParams
|
19
|
+
from matplotlib import image as mpimg
|
20
|
+
from matplotlib.colors import ListedColormap
|
15
21
|
from mudata import MuData
|
16
22
|
from numpyro.infer import HMC, MCMC, NUTS, initialization
|
17
23
|
from rich import box, print
|
@@ -20,10 +26,15 @@ from rich.table import Table
|
|
20
26
|
from scipy.cluster import hierarchy as sp_hierarchy
|
21
27
|
|
22
28
|
if TYPE_CHECKING:
|
29
|
+
from collections.abc import Sequence
|
30
|
+
|
23
31
|
import numpyro as npy
|
24
32
|
import toytree as tt
|
25
|
-
from
|
33
|
+
from ete3 import Tree
|
26
34
|
from jax._src.typing import Array
|
35
|
+
from matplotlib.axes import Axes
|
36
|
+
from matplotlib.colors import Colormap
|
37
|
+
from matplotlib.figure import Figure
|
27
38
|
|
28
39
|
config.update("jax_enable_x64", True)
|
29
40
|
|
@@ -179,7 +190,7 @@ class CompositionalModel2(ABC):
|
|
179
190
|
self,
|
180
191
|
sample_adata: AnnData,
|
181
192
|
kernel: npy.infer.mcmc.MCMCKernel,
|
182
|
-
rng_key: Array
|
193
|
+
rng_key: Array,
|
183
194
|
copy: bool = False,
|
184
195
|
*args,
|
185
196
|
**kwargs,
|
@@ -295,7 +306,7 @@ class CompositionalModel2(ABC):
|
|
295
306
|
if copy:
|
296
307
|
sample_adata = sample_adata.copy()
|
297
308
|
|
298
|
-
rng_key_array = random.
|
309
|
+
rng_key_array = random.key(rng_key)
|
299
310
|
sample_adata.uns["scCODA_params"]["mcmc"]["rng_key"] = np.array(rng_key_array)
|
300
311
|
|
301
312
|
# Set up NUTS kernel
|
@@ -335,7 +346,6 @@ class CompositionalModel2(ABC):
|
|
335
346
|
copy: Return a copy instead of writing to adata. Defaults to False.
|
336
347
|
|
337
348
|
Examples:
|
338
|
-
Example with scCODA:
|
339
349
|
>>> import pertpy as pt
|
340
350
|
>>> haber_cells = pt.dt.haber_2017_regions()
|
341
351
|
>>> sccoda = pt.tl.Sccoda()
|
@@ -358,10 +368,10 @@ class CompositionalModel2(ABC):
|
|
358
368
|
# Set rng key if needed
|
359
369
|
if rng_key is None:
|
360
370
|
rng = np.random.default_rng()
|
361
|
-
rng_key = random.
|
371
|
+
rng_key = random.key(rng.integers(0, 10000))
|
362
372
|
sample_adata.uns["scCODA_params"]["mcmc"]["rng_key"] = rng_key
|
363
373
|
else:
|
364
|
-
rng_key = random.
|
374
|
+
rng_key = random.key(rng_key)
|
365
375
|
|
366
376
|
# Set up HMC kernel
|
367
377
|
sample_adata = self.set_init_mcmc_states(
|
@@ -423,7 +433,6 @@ class CompositionalModel2(ABC):
|
|
423
433
|
- Is credible: Boolean indicator whether effect is credible
|
424
434
|
|
425
435
|
Examples:
|
426
|
-
Example with scCODA:
|
427
436
|
>>> import pertpy as pt
|
428
437
|
>>> haber_cells = pt.dt.haber_2017_regions()
|
429
438
|
>>> sccoda = pt.tl.Sccoda()
|
@@ -433,7 +442,6 @@ class CompositionalModel2(ABC):
|
|
433
442
|
>>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
|
434
443
|
>>> intercept_df, effect_df = sccoda.summary_prepare(mdata["coda"])
|
435
444
|
"""
|
436
|
-
# Get model and effect selection types
|
437
445
|
select_type = sample_adata.uns["scCODA_params"]["select_type"]
|
438
446
|
model_type = sample_adata.uns["scCODA_params"]["model_type"]
|
439
447
|
|
@@ -548,7 +556,11 @@ class CompositionalModel2(ABC):
|
|
548
556
|
intercept_df = intercept_df.loc[:, ["final_parameter", hdis[0], hdis[1], "sd", "expected_sample"]].copy()
|
549
557
|
intercept_df = intercept_df.rename(
|
550
558
|
columns=dict(
|
551
|
-
zip(
|
559
|
+
zip(
|
560
|
+
intercept_df.columns,
|
561
|
+
["Final Parameter", hdis_new[0], hdis_new[1], "SD", "Expected Sample"],
|
562
|
+
strict=False,
|
563
|
+
)
|
552
564
|
)
|
553
565
|
)
|
554
566
|
|
@@ -561,6 +573,7 @@ class CompositionalModel2(ABC):
|
|
561
573
|
zip(
|
562
574
|
effect_df.columns,
|
563
575
|
["Effect", "Median", hdis_new[0], hdis_new[1], "SD", "Expected Sample", "log2-fold change"],
|
576
|
+
strict=False,
|
564
577
|
)
|
565
578
|
)
|
566
579
|
)
|
@@ -581,6 +594,7 @@ class CompositionalModel2(ABC):
|
|
581
594
|
"Expected Sample",
|
582
595
|
"log2-fold change",
|
583
596
|
],
|
597
|
+
strict=False,
|
584
598
|
)
|
585
599
|
)
|
586
600
|
)
|
@@ -594,6 +608,7 @@ class CompositionalModel2(ABC):
|
|
594
608
|
zip(
|
595
609
|
node_df.columns,
|
596
610
|
["Final Parameter", "Median", hdis_new[0], hdis_new[1], "SD", "Delta", "Is credible"],
|
611
|
+
strict=False,
|
597
612
|
)
|
598
613
|
) # type: ignore
|
599
614
|
) # type: ignore
|
@@ -781,7 +796,6 @@ class CompositionalModel2(ABC):
|
|
781
796
|
kwargs: Passed to az.summary
|
782
797
|
|
783
798
|
Examples:
|
784
|
-
Example with scCODA:
|
785
799
|
>>> import pertpy as pt
|
786
800
|
>>> haber_cells = pt.dt.haber_2017_regions()
|
787
801
|
>>> sccoda = pt.tl.Sccoda()
|
@@ -799,7 +813,7 @@ class CompositionalModel2(ABC):
|
|
799
813
|
raise
|
800
814
|
if isinstance(data, AnnData):
|
801
815
|
sample_adata = data
|
802
|
-
|
816
|
+
|
803
817
|
select_type = sample_adata.uns["scCODA_params"]["select_type"]
|
804
818
|
model_type = sample_adata.uns["scCODA_params"]["model_type"]
|
805
819
|
|
@@ -926,7 +940,6 @@ class CompositionalModel2(ABC):
|
|
926
940
|
pd.DataFrame: Intercept data frame.
|
927
941
|
|
928
942
|
Examples:
|
929
|
-
Example with scCODA:
|
930
943
|
>>> import pertpy as pt
|
931
944
|
>>> haber_cells = pt.dt.haber_2017_regions()
|
932
945
|
>>> sccoda = pt.tl.Sccoda()
|
@@ -936,7 +949,6 @@ class CompositionalModel2(ABC):
|
|
936
949
|
>>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
|
937
950
|
>>> intercepts = sccoda.get_intercept_df(mdata)
|
938
951
|
"""
|
939
|
-
|
940
952
|
if isinstance(data, MuData):
|
941
953
|
try:
|
942
954
|
sample_adata = data[modality_key]
|
@@ -959,7 +971,6 @@ class CompositionalModel2(ABC):
|
|
959
971
|
pd.DataFrame: Effect data frame.
|
960
972
|
|
961
973
|
Examples:
|
962
|
-
Example with scCODA:
|
963
974
|
>>> import pertpy as pt
|
964
975
|
>>> haber_cells = pt.dt.haber_2017_regions()
|
965
976
|
>>> sccoda = pt.tl.Sccoda()
|
@@ -969,7 +980,6 @@ class CompositionalModel2(ABC):
|
|
969
980
|
>>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
|
970
981
|
>>> effects = sccoda.get_effect_df(mdata)
|
971
982
|
"""
|
972
|
-
|
973
983
|
if isinstance(data, MuData):
|
974
984
|
try:
|
975
985
|
sample_adata = data[modality_key]
|
@@ -1003,9 +1013,8 @@ class CompositionalModel2(ABC):
|
|
1003
1013
|
pd.DataFrame: Node effect data frame.
|
1004
1014
|
|
1005
1015
|
Examples:
|
1006
|
-
Example with tascCODA (works only for model of type tree_agg, i.e. a tascCODA model):
|
1007
1016
|
>>> import pertpy as pt
|
1008
|
-
>>> adata = pt.dt.
|
1017
|
+
>>> adata = pt.dt.tasccoda_example()
|
1009
1018
|
>>> tasccoda = pt.tl.Tasccoda()
|
1010
1019
|
>>> mdata = tasccoda.load(
|
1011
1020
|
>>> adata, type="sample_level",
|
@@ -1113,6 +1122,1136 @@ class CompositionalModel2(ABC):
|
|
1113
1122
|
|
1114
1123
|
return out
|
1115
1124
|
|
1125
|
+
def _stackbar(  # pragma: no cover
    self,
    y: np.ndarray,
    type_names: list[str],
    title: str,
    level_names: list[str],
    figsize: tuple[float, float] | None = None,
    dpi: int | None = 100,
    palette: ListedColormap | None = cm.tab20,
    show_legend: bool | None = True,
) -> plt.Axes:
    """Plots a stacked barplot for one (discrete) covariate.

    Every row of ``y`` becomes one bar. Each bar is normalised to 100 and split into one
    segment per cell type (column of ``y``).

    Typical use (only inside plot_stacked_barplot): self._stackbar(data.X, data.var.index, "xyz", data.obs.index)

    Args:
        y: The count data, collapsed onto the level of interest, i.e. a binary covariate has two rows,
            one for each group, containing the counts of each cell type. Objects exposing ``toarray()``
            (e.g. scipy sparse matrices) are densified first.
        type_names: The names of all cell types (one per column of ``y``).
        title: Plot title, usually the covariate's name.
        level_names: Names of the covariate's levels (one per row of ``y``).
        figsize: Figure size. Defaults to None, which uses ``rcParams["figure.figsize"]``.
        dpi: Dpi setting. Defaults to 100.
        palette: The color map for the barplot. Defaults to cm.tab20.
        show_legend: If True, adds a legend. Defaults to True.

    Returns:
        A :class:`~matplotlib.axes.Axes` object
    """
    # Densify sparse input so that the row/column arithmetic below works uniformly.
    counts = np.asarray(y.toarray() if hasattr(y, "toarray") else y, dtype=float)
    n_bars, n_types = counts.shape

    figsize = rcParams["figure.figsize"] if figsize is None else figsize
    _, ax = plt.subplots(figsize=figsize, dpi=dpi)

    # Percentage of every cell type within its bar. Bars whose total is zero stay at 0 instead of
    # becoming 0/0 = NaN, which would otherwise also turn all following bar bottoms into NaN.
    row_totals = counts.sum(axis=1, keepdims=True)
    percentages = np.divide(counts, row_totals, out=np.zeros_like(counts), where=row_totals != 0) * 100

    positions = np.arange(n_bars)
    bottoms = np.zeros(n_bars)
    for type_idx in range(n_types):
        heights = percentages[:, type_idx]
        # Draw on the Axes created above rather than on pyplot's implicit "current" Axes.
        ax.bar(
            positions,
            heights,
            bottom=bottoms,
            color=palette(type_idx % palette.N),
            width=0.85,
            label=type_names[type_idx],
            linewidth=0,
        )
        bottoms = bottoms + heights

    ax.set_title(title)
    if show_legend:
        ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1)
    ax.set_xticks(positions)
    ax.set_xticklabels(level_names, rotation=45, ha="right")
    ax.set_ylabel("Proportion")

    return ax
|
1186
|
+
|
1187
|
+
def plot_stacked_barplot(  # pragma: no cover
    self,
    data: AnnData | MuData,
    feature_name: str,
    modality_key: str = "coda",
    palette: ListedColormap | None = cm.tab20,
    show_legend: bool | None = True,
    level_order: list[str] | None = None,
    figsize: tuple[float, float] | None = None,
    dpi: int | None = 100,
    return_fig: bool | None = None,
    ax: plt.Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
    **kwargs,
) -> plt.Axes | plt.Figure | None:
    """Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples").

    Args:
        data: AnnData object or MuData object.
        feature_name: The name of the covariate to plot. If feature_name=="samples", one bar for every sample will be plotted
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        palette: The matplotlib color map for the barplot. Defaults to cm.tab20.
        show_legend: If True, adds a legend. Defaults to True.
        level_order: Custom ordering of bars on the x-axis. Must contain exactly the plotted levels
            (or sample names if feature_name=="samples"). Defaults to None.
        figsize: Figure size. Defaults to None.
        dpi: Dpi setting. Defaults to 100.
        return_fig: If True, return the current matplotlib Figure. Defaults to None.
        ax: Currently not used; a new figure is always created. Defaults to None.
        show: If True, call ``plt.show()``. Defaults to None.
        save: Path passed to ``plt.savefig``. Defaults to None.
        **kwargs: Currently not used.

    Returns:
        A :class:`~matplotlib.axes.Axes` object, the Figure if ``return_fig`` is set,
        or None if the plot was shown or saved.

    Raises:
        ValueError: If ``level_order`` does not contain exactly the available levels.

    Examples:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
                sample_identifier="batch", covariate_obs=["condition"])
        >>> sccoda.plot_stacked_barplot(mdata, feature_name="samples")

    Preview:
        .. image:: /_static/docstring_previews/sccoda_stacked_barplot.png
    """
    if isinstance(data, MuData):
        data = data[modality_key]

    ct_names = data.var.index

    if feature_name == "samples":
        # One bar per sample, optionally reordered by sample name.
        if level_order:
            # Explicit check instead of `assert`, which is stripped when Python runs with -O.
            if set(level_order) != set(data.obs.index):
                raise ValueError("level order is inconsistent with levels")
            data = data[level_order]
        ax = self._stackbar(
            data.X,
            type_names=data.var.index,
            title="samples",
            level_names=data.obs.index,
            figsize=figsize,
            dpi=dpi,
            palette=palette,
            show_legend=show_legend,
        )
    else:
        feature = data.obs[feature_name]

        # Determine the order of the levels (= bars).
        if level_order:
            if set(level_order) != set(feature):
                raise ValueError("level order is inconsistent with levels")
            levels = level_order
        elif hasattr(feature, "cat"):
            levels = feature.cat.categories.to_list()
        else:
            levels = pd.unique(feature)

        # Sum the cell type counts of all samples that belong to each level.
        feature_totals = np.zeros([len(levels), data.X.shape[1]])
        for row, level in enumerate(levels):
            feature_totals[row] = np.sum(data.X[np.where(feature == level)], axis=0)

        ax = self._stackbar(
            feature_totals,
            type_names=ct_names,
            title=feature_name,
            level_names=levels,
            figsize=figsize,
            dpi=dpi,
            palette=palette,
            show_legend=show_legend,
        )

    if save:
        plt.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return plt.gcf()
    if not (show or save):
        return ax
    return None
|
1287
|
+
|
1288
|
+
def plot_effects_barplot(  # pragma: no cover
    self,
    data: AnnData | MuData,
    modality_key: str = "coda",
    covariates: str | list | None = None,
    parameter: Literal["log2-fold change", "Final Parameter", "Expected Sample"] = "log2-fold change",
    plot_facets: bool = True,
    plot_zero_covariate: bool = True,
    plot_zero_cell_type: bool = False,
    palette: str | ListedColormap | None = cm.tab20,
    level_order: list[str] | None = None,
    args_barplot: dict | None = None,
    figsize: tuple[float, float] | None = None,
    dpi: int | None = 100,
    return_fig: bool | None = None,
    ax: plt.Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None:
    """Barplot visualization for effects.

    The effect results for each covariate are shown as a group of barplots, with intra--group separation by cell types.
    The covariates groups can either be ordered along the x-axis of a single plot (plot_facets=False) or as plot facets (plot_facets=True).

    Args:
        data: AnnData object or MuData object.
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        covariates: The name of the covariates in data.obs to plot. A covariate is selected if any of
            the given names is a substring of it. Defaults to None (all covariates).
        parameter: The parameter in effect summary to plot. Defaults to "log2-fold change".
        plot_facets: If False, plot cell types on the x-axis. If True, plot as facets.
            Defaults to True.
        plot_zero_covariate: If True, plot covariate that have all zero effects. If False, do not plot.
            Defaults to True.
        plot_zero_cell_type: If True, plot cell type that have zero effect. If False, do not plot.
            Defaults to False.
        palette: The seaborn color map for the barplot. Defaults to cm.tab20.
        level_order: Custom ordering of bars on the x-axis. Defaults to None.
        args_barplot: Arguments passed to sns.barplot. Defaults to None.
        figsize: ``(width, height)`` of a single facet if ``plot_facets=True``, otherwise of the figure.
            Defaults to None.
        dpi: Dpi setting of the figure created when ``plot_facets=False``. Defaults to 100.
        return_fig: If True, return the matplotlib Figure. Defaults to None.
        ax: Axes to draw on when ``plot_facets=False``. A new figure is created if None. Defaults to None.
        show: If True, call ``plt.show()``. Defaults to None.
        save: Path the figure is saved to. Defaults to None.

    Returns:
        Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
        or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object

    Raises:
        ValueError: If no covariate is left to plot after applying ``covariates`` and ``plot_zero_covariate``.

    Examples:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
                sample_identifier="batch", covariate_obs=["condition"])
        >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
        >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
        >>> sccoda.plot_effects_barplot(mdata)

    Preview:
        .. image:: /_static/docstring_previews/sccoda_effects_barplot.png
    """
    if args_barplot is None:
        args_barplot = {}
    if isinstance(data, MuData):
        data = data[modality_key]

    # Get covariate names from adata, partition into those with nonzero effects for min. one cell type/no cell types
    covariate_names = data.uns["scCODA_params"]["covariate_names"]
    if covariates is not None:
        if isinstance(covariates, str):
            covariates = [covariates]
        covariate_names = [name for name in covariate_names if any(covariate in name for covariate in covariates)]
    covariate_names_non_zero = [
        name for name in covariate_names if data.varm[f"effect_df_{name}"][parameter].any()
    ]
    # List comprehension (not a set difference) so the covariate order stays deterministic.
    covariate_names_zero = [name for name in covariate_names if name not in covariate_names_non_zero]
    if not plot_zero_covariate:
        covariate_names = covariate_names_non_zero
    if len(covariate_names) == 0:
        raise ValueError("No covariates left to plot. Check `covariates` and `plot_zero_covariate`.")

    # set up df for plotting: one row per (cell type, covariate) pair
    plot_df = pd.concat(
        [data.varm[f"effect_df_{name}"][parameter] for name in covariate_names],
        axis=1,
    )
    plot_df.columns = covariate_names
    plot_df = pd.melt(plot_df, ignore_index=False, var_name="Covariate").reset_index()

    has_placeholders = False
    if covariate_names_zero and plot_facets and plot_zero_covariate and not plot_zero_cell_type:
        # Drop zero effects, then give every all-zero covariate a single placeholder bar called "zero"
        # so that it still gets its own (empty) facet.
        plot_df = plot_df[plot_df["value"] != 0]
        placeholder_rows = pd.DataFrame(
            [{"Covariate": name, "Cell Type": "zero", "value": 0} for name in covariate_names_zero]
        )
        plot_df = pd.concat([plot_df, placeholder_rows], ignore_index=True)
        has_placeholders = True
        plot_df["covariate_"] = pd.Categorical(plot_df["Covariate"], covariate_names)
        # A stable sort keeps the cell type order within every covariate.
        plot_df = plot_df.sort_values(["covariate_"], kind="stable")

    if not plot_zero_cell_type:
        # Drop cell types whose effect is zero for every plotted covariate. The "zero" placeholder is
        # all-zero by construction and must survive, otherwise the facets of zero covariates vanish.
        all_zero = (plot_df["value"] == 0).groupby(plot_df["Cell Type"], sort=False).all()
        cell_type_names_zero = set(all_zero[all_zero].index)
        if has_placeholders:
            cell_type_names_zero.discard("zero")
        plot_df = plot_df[~plot_df["Cell Type"].isin(cell_type_names_zero)]

    # If plot as facets, create a FacetGrid and map barplot to it.
    if plot_facets:
        if isinstance(palette, ListedColormap):
            palette = np.array([palette(i % palette.N) for i in range(len(plot_df["Cell Type"].unique()))]).tolist()
        if figsize is not None:
            # figsize follows matplotlib's (width, height) convention; FacetGrid wants height and width/height.
            height = figsize[1]
            aspect = np.round(figsize[0] / figsize[1], 2)
        else:
            height = 3
            aspect = 2

        g = sns.FacetGrid(
            plot_df,
            col="Covariate",
            sharey=True,
            sharex=False,
            height=height,
            aspect=aspect,
        )
        g.map(
            sns.barplot,
            "Cell Type",
            "value",
            palette=palette,
            order=level_order,
            **args_barplot,
        )
        g.set_xticklabels(rotation=90)
        g.set(ylabel=parameter)
        # Title every facet with its own covariate instead of relying on positional order.
        g.set_titles("{col_name}")
        for facet_ax in g.axes.flatten():
            tick_labels = facet_ax.get_xticklabels()
            if 0 < len(tick_labels) < 5:
                # Aspect heuristic for facets with few bars (guarded against empty facets).
                facet_ax.set_aspect(10 / len(tick_labels))
            if len(tick_labels) == 1 and tick_labels[0].get_text() == "zero":
                # Hide the placeholder tick of an all-zero covariate.
                facet_ax.set_xticks([])

        if save:
            plt.savefig(save, bbox_inches="tight")
        if show:
            plt.show()
        if return_fig:
            return plt.gcf()
        if not (show or save):
            return g
        return None

    # If not plot as facets, call barplot to plot cell types on the x-axis.
    if ax is None:
        _, ax = plt.subplots(figsize=figsize, dpi=dpi)
    if len(covariate_names) == 1:
        if isinstance(palette, ListedColormap):
            palette = np.array([palette(i % palette.N) for i in range(len(plot_df["Cell Type"].unique()))]).tolist()
        # Colour the bars by cell type: hue mirrors x, so bars must not be dodged and the legend is redundant.
        sns.barplot(
            data=plot_df,
            x="Cell Type",
            y="value",
            hue="Cell Type",
            dodge=False,
            palette=palette,
            order=level_order,
            ax=ax,
            **args_barplot,
        )
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
        ax.set_title(covariate_names[0])
    else:
        if isinstance(palette, ListedColormap):
            palette = np.array([palette(i % palette.N) for i in range(len(covariate_names))]).tolist()
        sns.barplot(
            data=plot_df,
            x="Cell Type",
            y="value",
            hue="Covariate",
            palette=palette,
            order=level_order,
            ax=ax,
            **args_barplot,
        )
    # Rotate the existing labels; overwriting them would break a custom level_order.
    ax.tick_params(axis="x", labelrotation=90)
    ax.set_ylabel(parameter)

    fig = ax.figure
    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
|
1492
|
+
|
1493
|
+
def plot_boxplots(  # pragma: no cover
    self,
    data: AnnData | MuData,
    feature_name: str,
    modality_key: str = "coda",
    y_scale: Literal["relative", "log", "log10", "count"] = "relative",
    plot_facets: bool = False,
    add_dots: bool = False,
    cell_types: list | None = None,
    args_boxplot: dict | None = None,
    args_swarmplot: dict | None = None,
    palette: str | None = "Blues",
    show_legend: bool | None = True,
    level_order: list[str] | None = None,
    figsize: tuple[float, float] | None = None,
    dpi: int | None = 100,
    return_fig: bool | None = None,
    ax: plt.Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None:
    """Grouped boxplot visualization.

    The cell counts for each cell type are shown as a group of boxplots
    with intra--group separation by a covariate from data.obs.

    Args:
        data: AnnData object or MuData object
        feature_name: The name of the feature in data.obs to plot
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        y_scale: Transformation of the cell counts. Options: "relative" - Relative abundance, "log" - log(count),
            "log10" - log10(count), "count" - absolute abundance (cell counts).
            Zero counts get a pseudocount of 0.5 for the log scales.
            Defaults to "relative".
        plot_facets: If False, plot cell types on the x-axis. If True, plot as facets. Defaults to False.
        add_dots: If True, overlay a scatterplot with one dot for each data point. Defaults to False.
        cell_types: Subset of cell types that should be plotted. Defaults to None.
        args_boxplot: Arguments passed to sns.boxplot. The dict is not modified. Defaults to None.
        args_swarmplot: Arguments passed to sns.swarmplot. The dict is not modified. Defaults to None.
        palette: The seaborn color map for the boxplot. Defaults to "Blues".
        show_legend: If True, adds a legend, otherwise any legend is removed. Defaults to True.
        level_order: Custom ordering of the levels of ``feature_name``. Defaults to None.
        figsize: ``(width, height)`` of a single facet if ``plot_facets=True``, otherwise of the figure.
            Defaults to None.
        dpi: Dpi setting of the figure created when ``plot_facets=False``. Defaults to 100.
        return_fig: If True, return the matplotlib Figure. Defaults to None.
        ax: Axes to draw on when ``plot_facets=False``. A new figure is created if None. Defaults to None.
        show: If True, call ``plt.show()``. Defaults to None.
        save: Path the figure is saved to. Defaults to None.

    Returns:
        Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
        or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object

    Raises:
        ValueError: If ``y_scale`` is not one of the supported transformations.

    Examples:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
                sample_identifier="batch", covariate_obs=["condition"])
        >>> sccoda.plot_boxplots(mdata, feature_name="condition", add_dots=True)

    Preview:
        .. image:: /_static/docstring_previews/sccoda_boxplots.png
    """
    # Work on copies so that popping "hue" / adding "hue_order" never leaks into the caller's dicts.
    args_boxplot = {} if args_boxplot is None else dict(args_boxplot)
    args_swarmplot = {} if args_swarmplot is None else dict(args_swarmplot)
    if isinstance(data, MuData):
        data = data[modality_key]

    # Densify sparse count matrices before transforming them.
    counts = data.X.toarray() if hasattr(data.X, "toarray") else np.asarray(data.X)

    # y scale transformations
    if y_scale == "relative":
        X = counts / np.sum(counts, axis=1, keepdims=True)
        value_name = "Proportion"
    elif y_scale in ("log", "log10"):
        # Float copy with pseudocount 0.5: assigning 0.5 into an integer array would truncate it to 0 -> log(0) = -inf.
        X = counts.astype(float)
        X[X == 0] = 0.5
        if y_scale == "log":
            X = np.log(X)
            value_name = "log(count)"
        else:
            X = np.log10(X)
            value_name = "log10(count)"
    elif y_scale == "count":
        X = counts
        value_name = "count"
    else:
        raise ValueError("Invalid y_scale transformation")

    count_df = pd.DataFrame(X, columns=data.var.index, index=data.obs.index).merge(
        data.obs[feature_name], left_index=True, right_index=True
    )
    plot_df = pd.melt(count_df, id_vars=feature_name, var_name="Cell type", value_name=value_name)
    if cell_types is not None:
        plot_df = plot_df[plot_df["Cell type"].isin(cell_types)]

    # Currently disabled because the latest statsannotations does not support the latest seaborn.
    # We had to drop the dependency.
    # Get credible effects results from model
    # if draw_effects:
    #     if model is not None:
    #         credible_effects_df = model.credible_effects(data, modality_key).to_frame().reset_index()
    #     else:
    #         print("[bold yellow]Specify a tasCODA model to draw effects")
    #     credible_effects_df[feature_name] = credible_effects_df["Covariate"].str.removeprefix(f"{feature_name}[T.")
    #     credible_effects_df[feature_name] = credible_effects_df[feature_name].str.removesuffix("]")
    #     credible_effects_df = credible_effects_df[credible_effects_df["Final Parameter"]]

    # If plot as facets, create a FacetGrid and map boxplot to it.
    if plot_facets:
        if level_order is None:
            level_order = pd.unique(plot_df[feature_name])

        # Wrap based on the number of facets actually drawn (respects the `cell_types` subset).
        n_cell_types = len(pd.unique(plot_df["Cell type"]))

        if figsize is not None:
            # figsize follows matplotlib's (width, height) convention; FacetGrid wants height and width/height.
            height = figsize[1]
            aspect = np.round(figsize[0] / figsize[1], 2)
        else:
            height = 3
            aspect = 2

        g = sns.FacetGrid(
            plot_df,
            col="Cell type",
            sharey=False,
            col_wrap=max(1, int(np.floor(np.sqrt(n_cell_types)))),
            height=height,
            aspect=aspect,
        )
        g.map(
            sns.boxplot,
            feature_name,
            value_name,
            palette=palette,
            order=level_order,
            **args_boxplot,
        )

        if add_dots:
            hue = args_swarmplot.pop("hue", None)
            if hue is None:
                g.map(
                    sns.swarmplot,
                    feature_name,
                    value_name,
                    color="black",
                    order=level_order,
                    **args_swarmplot,
                )
            else:
                g.map(
                    sns.swarmplot,
                    feature_name,
                    value_name,
                    hue,
                    order=level_order,
                    **args_swarmplot,
                )
        g.set_titles("{col_name}")

        if save:
            plt.savefig(save, bbox_inches="tight")
        if show:
            plt.show()
        if return_fig:
            return plt.gcf()
        if not (show or save):
            return g
        return None

    # If not plot as facets, call boxplot to plot cell types on the x-axis.
    if level_order:
        args_boxplot["hue_order"] = level_order
        args_swarmplot["hue_order"] = level_order

    if ax is None:
        _, ax = plt.subplots(figsize=figsize, dpi=dpi)

    ax = sns.boxplot(
        x="Cell type",
        y=value_name,
        hue=feature_name,
        data=plot_df,
        fliersize=1,
        palette=palette,
        ax=ax,
        **args_boxplot,
    )

    # Currently disabled because the latest statsannotations does not support the latest seaborn.
    # We had to drop the dependency.
    # if draw_effects:
    #     pairs = [
    #         [(row["Cell Type"], row[feature_name]), (row["Cell Type"], "Control")]
    #         for _, row in credible_effects_df.iterrows()
    #     ]
    #     annot = Annotator(ax, pairs, data=plot_df, x="Cell type", y=value_name, hue=feature_name)
    #     annot.configure(test=None, loc="outside", color="red", line_height=0, verbose=False)
    #     annot.set_custom_annotations([row[feature_name] for _, row in credible_effects_df.iterrows()])
    #     annot.annotate()

    if add_dots:
        # Defaults first so a user-supplied "dodge"/"palette" overrides them instead of raising a duplicate-kwarg TypeError.
        swarm_kwargs = {"dodge": True, "palette": "dark:black", **args_swarmplot}
        sns.swarmplot(
            x="Cell type",
            y=value_name,
            data=plot_df,
            hue=feature_name,
            ax=ax,
            **swarm_kwargs,
        )

    ax.tick_params(axis="x", labelrotation=90)

    if show_legend:
        # Boxplot and swarmplot both add entries per level; keep only the first handle of each label.
        first_handles: dict = {}
        handles, labels = ax.get_legend_handles_labels()
        for handle, label in zip(handles, labels, strict=False):
            first_handles.setdefault(label, handle)
        ax.legend(
            list(first_handles.values()),
            list(first_handles.keys()),
            loc="upper left",
            bbox_to_anchor=(1, 1),
            ncol=1,
            title=feature_name,
        )
    elif ax.get_legend() is not None:
        # seaborn adds a hue legend automatically; honour show_legend=False by removing it.
        ax.get_legend().remove()

    fig = ax.figure
    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
|
1739
|
+
|
1740
|
+
def plot_rel_abundance_dispersion_plot(  # pragma: no cover
    self,
    data: AnnData | MuData,
    modality_key: str = "coda",
    abundant_threshold: float | None = 0.9,
    default_color: str | None = "Grey",
    abundant_color: str | None = "Red",
    label_cell_types: bool = True,
    figsize: tuple[float, float] | None = None,
    dpi: int | None = 100,
    return_fig: bool | None = None,
    ax: plt.Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> plt.Axes | plt.Figure | None:
    """Plots the dispersion of relative abundance versus presence of all cell types for determination of a reference cell type.

    For every cell type, the dispersion is the variance-to-mean ratio of its relative abundance across samples,
    and the presence is the fraction of samples in which its count is larger than 0.
    If a cell type is present in more than `abundant_threshold` of all samples, it is marked as abundant
    and drawn in a different color.

    Args:
        data: AnnData or MuData object.
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        abundant_threshold: Presence threshold (fraction of samples, between 0 and 1) for abundant cell types.
                            Defaults to 0.9.
        default_color: Dot color for all non-abundant cell types. Defaults to "Grey".
        abundant_color: Dot color for cell types with a presence larger than abundant_threshold.
                        Defaults to "Red".
        label_cell_types: Label the dots of abundant cell types with their names. Defaults to True.
        figsize: Figure size, only used when `ax` is None. Defaults to None.
        dpi: Dpi setting, only used when `ax` is None. Defaults to 100.
        return_fig: If True, return the current matplotlib figure. Defaults to None.
        ax: A matplotlib axes object to draw into. If None, a new figure and axes are created. Defaults to None.
        show: If True, call `plt.show()`. Defaults to None.
        save: Path the current figure is saved to. Defaults to None.

    Returns:
        The current figure if `return_fig` is True, the :class:`~matplotlib.axes.Axes` object if neither
        `show` nor `save` is set, otherwise None.

    Examples:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
                sample_identifier="batch", covariate_obs=["condition"])
        >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
        >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
        >>> sccoda.plot_rel_abundance_dispersion_plot(mdata)

    Preview:
        .. image:: /_static/docstring_previews/sccoda_rel_abundance_dispersion_plot.png
    """
    if isinstance(data, MuData):
        data = data[modality_key]
    if ax is None:
        _, ax = plt.subplots(figsize=figsize, dpi=dpi)

    counts = data.X
    # Normalize every sample (row) to sum to 1.
    rel_abun = counts / np.sum(counts, axis=1, keepdims=True)

    # Fraction of samples in which each cell type (column) has a zero count.
    percent_zero = np.sum(counts == 0, axis=0) / counts.shape[0]
    # Boolean mask over cell types: present in more than `abundant_threshold` of the samples.
    is_abundant = percent_zero < 1 - abundant_threshold

    # Variance-to-mean ratio of the relative abundance of every cell type.
    cell_type_disp = np.var(rel_abun, axis=0) / np.mean(rel_abun, axis=0)

    plot_df = pd.DataFrame(
        {
            "Total dispersion": cell_type_disp,
            "Cell type": data.var.index,
            "Presence": 1 - percent_zero,
            "Is abundant": is_abundant,
        }
    )

    # Map each hue level to its color explicitly, so the colors are correct
    # whether one or both levels occur in the data.
    palette = {False: default_color, True: abundant_color}

    ax = sns.scatterplot(
        data=plot_df,
        x="Presence",
        y="Total dispersion",
        hue="Is abundant",
        palette=palette,
        ax=ax,
    )

    if label_cell_types:
        # Text labels for abundant cell types, drawn on the axes we plotted into
        # (not `plt.gca()`, which may be a different axes when `ax` was passed in).
        abundant_df = plot_df.loc[plot_df["Is abundant"], :]
        texts = [
            ax.text(row["Presence"], row["Total dispersion"], str(row["Cell type"]))
            for _, row in abundant_df.iterrows()
        ]
        if texts:
            adjust_text(texts, ax=ax)

    ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1, title="Is abundant")

    if save:
        plt.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return plt.gcf()
    if not (show or save):
        return ax
    return None
|
1866
|
+
|
1867
|
+
def plot_draw_tree(  # pragma: no cover
    self,
    data: AnnData | MuData,
    modality_key: str = "coda",
    tree: str = "tree",  # Also type ete3.Tree. Omitted due to import errors
    tight_text: bool | None = False,
    show_scale: bool | None = False,
    units: Literal["px", "mm", "in"] | None = "px",
    figsize: tuple[float, float] | None = (None, None),
    dpi: int | None = 100,
    show: bool | None = True,
    save: str | bool | None = None,
) -> Tree | None:
    """Plot a tree using input ete3 tree object.

    Args:
        data: AnnData object or MuData object.
        modality_key: If data is a MuData object, specify which modality to use.
                      Defaults to "coda".
        tree: A ete3 tree object or a str to indicate the tree stored in `.uns`.
              Defaults to "tree".
        tight_text: When False, boundaries of the text are approximated according to general font metrics,
                    producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
                    Default to False.
        show_scale: Include the scale legend in the tree image or not.
                    Defaults to False.
        units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches. Defaults to "px".
        figsize: Figure size as (width, height). None, or None entries, let ete3 choose. Defaults to (None, None).
        dpi: Dots per inches. Defaults to 100.
        show: If True, plot the tree inline. If false, return tree and tree_style objects.
              Defaults to True.
        save: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG.
              Output image can be saved whether show is True or not.
              Defaults to None.

    Returns:
        Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`)
        or plots the tree inline (`show = True`).

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.tasccoda_example()
        >>> tasccoda = pt.tl.Tasccoda()
        >>> mdata = tasccoda.load(
        >>>     adata, type="sample_level",
        >>>     levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
        >>>     key_added="lineage", add_level_name=True
        >>> )
        >>> mdata = tasccoda.prepare(
        >>>     mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
        >>> )
        >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
        >>> tasccoda.plot_draw_tree(mdata, tree="lineage")

    Preview:
        .. image:: /_static/docstring_previews/tasccoda_draw_tree.png
    """
    try:
        from ete3 import TextFace, TreeStyle, faces
    except ImportError:
        raise ImportError(
            "To use tasccoda please install additional dependencies with `pip install pertpy[coda]`"
        ) from None

    if isinstance(data, MuData):
        data = data[modality_key]
    if isinstance(tree, str):
        tree = data.uns[tree]

    # ete3 takes width and height separately; None lets it choose the size itself.
    width, height = figsize if figsize is not None else (None, None)

    def my_layout(node):
        # Draw every node's name to the right of its branch.
        text_face = TextFace(node.name, tight_text=tight_text)
        faces.add_face_to_node(text_face, node, column=0, position="branch-right")

    tree_style = TreeStyle()
    tree_style.show_leaf_name = False
    tree_style.layout_fn = my_layout
    tree_style.show_scale = show_scale

    # Only a real path triggers saving; `False`/empty string mean "don't save".
    if save:
        tree.render(save, tree_style=tree_style, units=units, w=width, h=height, dpi=dpi)  # type: ignore
    if show:
        return tree.render("%%inline", tree_style=tree_style, units=units, w=width, h=height, dpi=dpi)  # type: ignore
    return tree, tree_style
|
1952
|
+
|
1953
|
+
def plot_draw_effects(  # pragma: no cover
    self,
    data: AnnData | MuData,
    covariate: str,
    modality_key: str = "coda",
    tree: str = "tree",  # Also type ete3.Tree. Omitted due to import errors
    show_legend: bool | None = None,
    show_leaf_effects: bool | None = False,
    tight_text: bool | None = False,
    show_scale: bool | None = False,
    units: Literal["px", "mm", "in"] | None = "px",
    figsize: tuple[float, float] | None = (None, None),
    dpi: int | None = 100,
    show: bool | None = True,
    save: str | None = None,
) -> Tree | None:
    """Plot a tree with colored circles on the nodes indicating significant effects with bar plots which indicate leave-level significant effects.

    Args:
        data: AnnData object or MuData object.
        covariate: The covariate, whose effects should be plotted.
        modality_key: If data is a MuData object, specify which modality to use.
                      Defaults to "coda".
        tree: A ete3 tree object or a str to indicate the tree stored in `.uns`.
              Defaults to "tree".
        show_legend: If show legend of nodes significant effects or not.
                     Defaults to False if show_leaf_effects is True.
        show_leaf_effects: If True, plot bar plots which indicate leave-level significant effects.
                           Defaults to False.
        tight_text: When False, boundaries of the text are approximated according to general font metrics,
                    producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
                    Defaults to False.
        show_scale: Include the scale legend in the tree image or not. Defaults to False.
        units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches. Defaults to "px".
        figsize: Figure size as (width, height) of the tree image. None, or None entries, let ete3 choose.
                 Defaults to (None, None).
        dpi: Dots per inches of the tree image. Defaults to 100.
        show: If True, plot the tree inline. If false, return tree and tree_style objects. Defaults to True.
        save: Path to the output image file. valid extensions are .SVG, .PDF, .PNG. Output image can be saved whether show is True or not.
              Defaults to None.

    Returns:
        If `show_leaf_effects` is True, the tree image and the bar plot are drawn into a new matplotlib figure and None is returned.
        Otherwise, depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`)
        or plots the tree inline (`show = True`).

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.tasccoda_example()
        >>> tasccoda = pt.tl.Tasccoda()
        >>> mdata = tasccoda.load(
        >>>     adata, type="sample_level",
        >>>     levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
        >>>     key_added="lineage", add_level_name=True
        >>> )
        >>> mdata = tasccoda.prepare(
        >>>     mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
        >>> )
        >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
        >>> tasccoda.plot_draw_effects(mdata, covariate="Health[T.Inflamed]", tree="lineage")

    Preview:
        .. image:: /_static/docstring_previews/tasccoda_draw_effects.png
    """
    try:
        from ete3 import CircleFace, NodeStyle, TextFace, TreeStyle, faces
    except ImportError:
        raise ImportError(
            "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
        ) from None
    import tempfile

    if isinstance(data, MuData):
        data = data[modality_key]
    if show_legend is None:
        show_legend = not show_leaf_effects
    elif show_legend:
        print("Tree leaves and leaf effect bars won't be aligned when legend is shown!")

    # ete3 takes width and height separately; None lets it choose the size itself.
    width, height = figsize if figsize is not None else (None, None)

    if isinstance(tree, str):
        tree = data.uns[tree]
    # Collapse tree singularities
    tree2 = collapse_singularities_2(tree)

    # Node-level effects of the requested covariate, indexed by node name.
    node_effs = data.uns["scCODA_params"]["node_df"].loc[(covariate + "_node",),].copy()
    node_effs.index = node_effs.index.get_level_values("Node")

    # Leaf-level (cell type) effects of the requested covariate, indexed by cell type.
    covariates = data.uns["scCODA_params"]["covariate_names"]
    effect_dfs = [data.varm[f"effect_df_{cov}"] for cov in covariates]
    eff_df = pd.concat(effect_dfs)
    eff_df.index = pd.MultiIndex.from_product(
        (covariates, data.var.index.tolist()),
        names=["Covariate", "Cell Type"],
    )
    leaf_effs = eff_df.loc[(covariate,),].copy()
    leaf_effs.index = leaf_effs.index.get_level_values("Cell Type")

    # Attach effect values to every node; nodes without an entry get 0.
    for n in tree2.traverse():
        nstyle = NodeStyle()
        nstyle["size"] = 0
        n.set_style(nstyle)
        if n.name in node_effs.index:
            n.add_feature("node_effect", node_effs.loc[n.name, "Final Parameter"])
        else:
            n.add_feature("node_effect", 0)
        if n.name in leaf_effs.index:
            n.add_feature("leaf_effect", leaf_effs.loc[n.name, "Effect"])
        else:
            n.add_feature("leaf_effect", 0)

    # Scale effect values to get nice node sizes
    eff_max = np.max([np.abs(n.node_effect) for n in tree2.traverse()])
    leaf_eff_max = np.max([np.abs(n.leaf_effect) for n in tree2.traverse()])

    def my_layout(node):
        text_face = TextFace(node.name, tight_text=tight_text)
        text_face.margin_left = 10
        faces.add_face_to_node(text_face, node, column=0, aligned=True)

        # Circle radius grows linearly with |effect|; the largest effect gets radius 10.
        # A non-zero effect implies eff_max > 0, so the division is safe.
        size = (np.abs(node.node_effect) * 10 / eff_max) if node.node_effect != 0 else 0
        if np.sign(node.node_effect) == 1:
            color = "blue"
        elif np.sign(node.node_effect) == -1:
            color = "red"
        else:
            color = "cyan"
        if size != 0:
            faces.add_face_to_node(CircleFace(radius=size, color=color), node, column=0)

    tree_style = TreeStyle()
    tree_style.show_leaf_name = False
    tree_style.layout_fn = my_layout
    tree_style.show_scale = show_scale
    tree_style.draw_guiding_lines = True
    tree_style.legend_position = 1

    # Without any non-zero node effect the legend scale would be 0/0, so it is skipped.
    if show_legend and eff_max > 0:
        tree_style.legend.add_face(TextFace("Effects"), column=0)
        tree_style.legend.add_face(TextFace(" "), column=1)
        for i in range(4, 0, -1):
            # Same radius scale as the node circles: i/4 of the maximum radius 10.
            radius = 10 * i / 4
            tree_style.legend.add_face(CircleFace(radius, "red"), column=0)
            tree_style.legend.add_face(TextFace(f"{-eff_max * i / 4:.2f} "), column=0)
            tree_style.legend.add_face(CircleFace(radius, "blue"), column=1)
            tree_style.legend.add_face(TextFace(f" {eff_max * i / 4:.2f}"), column=1)

    if show_leaf_effects:
        leaf_name = [node.name for node in tree2.traverse("postorder") if node.is_leaf()]
        leaf_effs = leaf_effs.loc[leaf_name].reset_index()
        palette = ["blue" if effect > 0 else "red" for effect in leaf_effs["Effect"].tolist()]

        # Render the tree to a temporary PNG so it can be shown next to the bar plot
        # without leaving a stray file in the working directory. ete3 expects a str file name.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tree_img_path = str(Path(tmp_dir) / "tree_effect.png")
            tree2.render(tree_img_path, tree_style=tree_style, units="in")
            img = mpimg.imread(tree_img_path)

        fig, axes = plt.subplots(1, 2, figsize=(10, 10))
        sns.barplot(data=leaf_effs, x="Effect", y="Cell Type", palette=palette, ax=axes[1])
        axes[0].imshow(img)
        axes[0].get_xaxis().set_visible(False)
        axes[0].get_yaxis().set_visible(False)
        axes[0].set_frame_on(False)

        axes[1].get_yaxis().set_visible(False)
        axes[1].spines["left"].set_visible(False)
        axes[1].spines["right"].set_visible(False)
        axes[1].spines["top"].set_visible(False)
        axes[1].set_xlim(-leaf_eff_max, leaf_eff_max)
        fig.subplots_adjust(wspace=0)

        if save:
            fig.savefig(save)
        return None

    if save:
        tree2.render(save, tree_style=tree_style, units=units, w=width, h=height, dpi=dpi)
    if show:
        return tree2.render("%%inline", tree_style=tree_style, units=units, w=width, h=height, dpi=dpi)
    return tree2, tree_style
|
2148
|
+
|
2149
|
+
def plot_effects_umap(  # pragma: no cover
    self,
    mdata: MuData,
    effect_name: str | list | None,
    cluster_key: str,
    modality_key_1: str = "rna",
    modality_key_2: str = "coda",
    color_map: Colormap | str | None = None,
    palette: str | Sequence[str] | None = None,
    return_fig: bool | None = None,
    ax: Axes = None,
    show: bool = None,
    save: str | bool | None = None,
    **kwargs,
) -> plt.Axes | plt.Figure | None:
    """Plot a UMAP visualization colored by effect strength.

    Effect results in .varm of aggregated sample-level AnnData (default is data['coda']) are assigned to cell-level AnnData
    (default is data['rna']) depending on the cluster they were assigned to.
    One column per plotted effect is added to `.obs` of the cell-level AnnData.

    Args:
        mdata: MuData object.
        effect_name: The name(s) of the effect results in .varm of aggregated sample-level AnnData to plot.
                     If None, all `effect_df_{covariate}` tables for the covariates in
                     `.uns["scCODA_params"]["covariate_names"]` are plotted.
        cluster_key: The cluster information in .obs of cell-level AnnData (default is data['rna']).
                     To assign cell types' effects to original cells.
        modality_key_1: Key to the cell-level AnnData in the MuData object. Defaults to "rna".
        modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object.
                        Defaults to "coda".
        color_map: Colormap passed to `scanpy.plot.umap()`. Defaults to None.
        palette: Palette passed to `scanpy.plot.umap()`. Defaults to None.
        return_fig: Passed to `scanpy.plot.umap()`. Defaults to None.
        ax: A matplotlib axes object. Only works if plotting a single component.
            Defaults to None.
        show: Whether to display the figure or return axis. Defaults to None.
        save: Passed to `scanpy.plot.umap()`. Defaults to None.
        **kwargs: All other keyword arguments are passed to `scanpy.plot.umap()`.
                  If `vmin`/`vmax` are not given, the minimum/maximum over all plotted effects is used.

    Returns:
        If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.

    Examples:
        >>> import pertpy as pt
        >>> import scanpy as sc
        >>> import schist
        >>> adata = pt.dt.haber_2017_regions()
        >>> sc.pp.neighbors(adata)
        >>> schist.inference.nested_model(adata, n_init=100, random_seed=5678)
        >>> tasccoda_model = pt.tl.Tasccoda()
        >>> tasccoda_data = tasccoda_model.load(adata, type="cell_level",
        >>>                 cell_type_identifier="nsbm_level_1",
        >>>                 sample_identifier="batch", covariate_obs=["condition"],
        >>>                 levels_orig=["nsbm_level_4", "nsbm_level_3", "nsbm_level_2", "nsbm_level_1"],
        >>>                 add_level_name=True)
        >>> tasccoda_model.prepare(
        >>>     tasccoda_data,
        >>>     modality_key="coda",
        >>>     reference_cell_type="18",
        >>>     formula="condition",
        >>>     pen_args={"phi": 0, "lambda_1": 3.5},
        >>>     tree_key="tree"
        >>> )
        >>> tasccoda_model.run_nuts(
        ...     tasccoda_data, modality_key="coda", rng_key=1234, num_samples=10000, num_warmup=1000
        ... )
        >>> sc.tl.umap(tasccoda_data["rna"])
        >>> tasccoda_model.plot_effects_umap(tasccoda_data,
        >>>                         effect_name=["effect_df_condition[T.Salmonella]",
        >>>                                      "effect_df_condition[T.Hpoly.Day3]",
        >>>                                      "effect_df_condition[T.Hpoly.Day10]"],
        >>>                                       cluster_key="nsbm_level_1",
        >>>                         )

    Preview:
        .. image:: /_static/docstring_previews/tasccoda_effects_umap.png
    """
    data_rna = mdata[modality_key_1]
    data_coda = mdata[modality_key_2]

    if effect_name is None:
        # Default to the effect tables of all covariates (`.varm["effect_df_{covariate}"]`).
        covariates = data_coda.uns["scCODA_params"]["covariate_names"]
        effect_name = [f"effect_df_{cov}" for cov in covariates]
    elif isinstance(effect_name, str):
        effect_name = [effect_name]

    # Give every cell the effect of the cluster it belongs to.
    for effect in effect_name:
        data_rna.obs[effect] = [data_coda.varm[effect].loc[f"{c}", "Effect"] for c in data_rna.obs[cluster_key]]

    # Pop the limits so they are never passed twice to scanpy; a given 0 is a valid limit.
    # Without explicit limits, all panels share the range over all plotted effects.
    vmin = kwargs.pop("vmin", None)
    if vmin is None:
        vmin = min(data_rna.obs[effect].min() for effect in effect_name)
    vmax = kwargs.pop("vmax", None)
    if vmax is None:
        vmax = max(data_rna.obs[effect].max() for effect in effect_name)

    return sc.pl.umap(
        data_rna,
        color=effect_name,
        vmax=vmax,
        vmin=vmin,
        palette=palette,
        color_map=color_map,
        return_fig=return_fig,
        ax=ax,
        show=show,
        save=save,
        **kwargs,
    )
|
2254
|
+
|
1116
2255
|
|
1117
2256
|
def get_a(
|
1118
2257
|
tree: tt.tree,
|
@@ -1242,7 +2381,7 @@ def df2newick(df: pd.DataFrame, levels: list[str], inner_label: bool = True) ->
|
|
1242
2381
|
|
1243
2382
|
|
1244
2383
|
def get_a_2(
|
1245
|
-
tree:
|
2384
|
+
tree: Tree,
|
1246
2385
|
leaf_order: list[str] = None,
|
1247
2386
|
node_order: list[str] = None,
|
1248
2387
|
) -> tuple[np.ndarray, int]:
|
@@ -1263,6 +2402,13 @@ def get_a_2(
|
|
1263
2402
|
T
|
1264
2403
|
number of nodes in the tree, excluding the root node
|
1265
2404
|
"""
|
2405
|
+
try:
|
2406
|
+
import ete3 as ete
|
2407
|
+
except ImportError:
|
2408
|
+
raise ImportError(
|
2409
|
+
"To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
|
2410
|
+
) from None
|
2411
|
+
|
1266
2412
|
n_tips = len(tree.get_leaves())
|
1267
2413
|
n_nodes = len(tree.get_descendants())
|
1268
2414
|
|
@@ -1292,7 +2438,7 @@ def get_a_2(
|
|
1292
2438
|
return A_, n_nodes
|
1293
2439
|
|
1294
2440
|
|
1295
|
-
def collapse_singularities_2(tree:
|
2441
|
+
def collapse_singularities_2(tree: Tree) -> Tree:
|
1296
2442
|
"""Collapses (deletes) nodes in a ete3 tree that are singularities (have only one child).
|
1297
2443
|
|
1298
2444
|
Args:
|
@@ -1327,10 +2473,10 @@ def linkage_to_newick(
|
|
1327
2473
|
|
1328
2474
|
def build_newick(node, newick, parentdist, leaf_names):
|
1329
2475
|
if node.is_leaf():
|
1330
|
-
return f"{leaf_names[node.id]}:{(parentdist - node.dist)/2}{newick}"
|
2476
|
+
return f"{leaf_names[node.id]}:{(parentdist - node.dist) / 2}{newick}"
|
1331
2477
|
else:
|
1332
2478
|
if len(newick) > 0:
|
1333
|
-
newick = f"):{(parentdist - node.dist)/2}{newick}"
|
2479
|
+
newick = f"):{(parentdist - node.dist) / 2}{newick}"
|
1334
2480
|
else:
|
1335
2481
|
newick = ");"
|
1336
2482
|
newick = build_newick(node.get_left(), newick, node.dist, leaf_names)
|
@@ -1363,14 +2509,15 @@ def import_tree(
|
|
1363
2509
|
|
1364
2510
|
Args:
|
1365
2511
|
data: A tascCODA-compatible data object.
|
1366
|
-
modality_1: If `data` is MuData,
|
1367
|
-
modality_2: If `data` is MuData,
|
2512
|
+
modality_1: If `data` is MuData, specify the modality name to the original cell level anndata object. Defaults to None.
|
2513
|
+
modality_2: If `data` is MuData, specify the modality name to the aggregated level anndata object. Defaults to None.
|
1368
2514
|
dendrogram_key: Key to the scanpy.tl.dendrogram result in `.uns` of original cell level anndata object. Defaults to None.
|
1369
2515
|
levels_orig: List that indicates which columns in `.obs` of the original data correspond to tree levels. The list must begin with the root level, and end with the leaf level. Defaults to None.
|
1370
2516
|
levels_agg: List that indicates which columns in `.var` of the aggregated data correspond to tree levels. The list must begin with the root level, and end with the leaf level. Defaults to None.
|
1371
|
-
add_level_name: If True, internal nodes in the tree will be named as "{level_name}_{node_name}" instead of just {level_name}.
|
1372
|
-
|
1373
|
-
|
2517
|
+
add_level_name: If True, internal nodes in the tree will be named as "{level_name}_{node_name}" instead of just {level_name}.
|
2518
|
+
Defaults to True.
|
2519
|
+
key_added: If not specified, the tree is stored in .uns[‘tree’]. If `data` is AnnData, save tree in `data`.
|
2520
|
+
If `data` is MuData, save tree in data[modality_2]. Defaults to "tree".
|
1374
2521
|
|
1375
2522
|
Returns:
|
1376
2523
|
Updates data with the following:
|
@@ -1379,6 +2526,13 @@ def import_tree(
|
|
1379
2526
|
|
1380
2527
|
tree: A ete3 tree object.
|
1381
2528
|
"""
|
2529
|
+
try:
|
2530
|
+
import ete3 as ete
|
2531
|
+
except ImportError:
|
2532
|
+
raise ImportError(
|
2533
|
+
"To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
|
2534
|
+
) from None
|
2535
|
+
|
1382
2536
|
if isinstance(data, MuData):
|
1383
2537
|
try:
|
1384
2538
|
data_1 = data[modality_1]
|
@@ -1443,16 +2597,17 @@ def from_scanpy(
|
|
1443
2597
|
|
1444
2598
|
The anndata object needs to have a column in adata.obs that contains the cell type assignment.
|
1445
2599
|
Further, it must contain one column or a set of columns (e.g. subject id, treatment, disease status) that uniquely identify each (statistical) sample.
|
1446
|
-
Further covariates (e.g. subject age) can either be specified via
|
2600
|
+
Further covariates (e.g. subject age) can either be specified via additional column names in adata.obs, a key in adata.uns, or as a separate DataFrame.
|
1447
2601
|
|
1448
|
-
NOTE: The order of samples in the returned dataset is determined by the first
|
2602
|
+
NOTE: The order of samples in the returned dataset is determined by the first occurrence of cells from each sample in `adata`
|
1449
2603
|
|
1450
2604
|
Args:
|
1451
2605
|
adata: An anndata object from scanpy
|
1452
2606
|
cell_type_identifier: column name in adata.obs that specifies the cell types
|
1453
2607
|
sample_identifier: column name or list of column names in adata.obs that uniquely identify each sample
|
1454
2608
|
covariate_uns: key for adata.uns, where covariate values are stored
|
1455
|
-
covariate_obs: list of column names in adata.obs, where covariate values are stored.
|
2609
|
+
covariate_obs: list of column names in adata.obs, where covariate values are stored.
|
2610
|
+
Note: If covariate values are not unique for a value of sample_identifier, this covariate will be skipped.
|
1456
2611
|
covariate_df: DataFrame with covariates
|
1457
2612
|
|
1458
2613
|
Returns:
|
@@ -1461,50 +2616,40 @@ def from_scanpy(
|
|
1461
2616
|
if isinstance(sample_identifier, str):
|
1462
2617
|
sample_identifier = [sample_identifier]
|
1463
2618
|
|
1464
|
-
if
|
1465
|
-
covariate_obs += [i for i in sample_identifier if i not in covariate_obs]
|
1466
|
-
else:
|
1467
|
-
covariate_obs = sample_identifier # type: ignore
|
1468
|
-
|
1469
|
-
# join sample identifiers
|
1470
|
-
if isinstance(sample_identifier, list):
|
2619
|
+
if len(sample_identifier) > 1:
|
1471
2620
|
adata.obs["scCODA_sample_id"] = adata.obs[sample_identifier].agg("-".join, axis=1)
|
1472
2621
|
sample_identifier = "scCODA_sample_id"
|
2622
|
+
else:
|
2623
|
+
sample_identifier = sample_identifier[0]
|
1473
2624
|
|
1474
2625
|
# get cell type counts
|
1475
|
-
|
1476
|
-
|
1477
|
-
count_data = count_data.fillna(0)
|
2626
|
+
ct_count_data = pd.crosstab(adata.obs[sample_identifier], adata.obs[cell_type_identifier])
|
2627
|
+
ct_count_data = ct_count_data.fillna(0)
|
1478
2628
|
|
1479
2629
|
# get covariates from different sources
|
1480
|
-
covariate_df_ = pd.DataFrame(index=
|
1481
|
-
|
1482
|
-
if covariate_df is None and covariate_obs is None and covariate_uns is None:
|
1483
|
-
print("No covariate information specified!")
|
2630
|
+
covariate_df_ = pd.DataFrame(index=ct_count_data.index)
|
1484
2631
|
|
1485
2632
|
if covariate_uns is not None:
|
1486
|
-
covariate_df_uns = pd.DataFrame(adata.uns[covariate_uns])
|
1487
|
-
covariate_df_ =
|
2633
|
+
covariate_df_uns = pd.DataFrame(adata.uns[covariate_uns], index=ct_count_data.index)
|
2634
|
+
covariate_df_ = covariate_df_.join(covariate_df_uns, how="left")
|
1488
2635
|
|
1489
|
-
if covariate_obs
|
1490
|
-
|
1491
|
-
|
1492
|
-
print(f"Covariate {c} has non-unique values! Skipping...")
|
1493
|
-
covariate_obs.remove(c)
|
2636
|
+
if covariate_obs:
|
2637
|
+
is_unique = adata.obs.groupby(sample_identifier, observed=True).transform(lambda x: x.nunique() == 1)
|
2638
|
+
unique_covariates = is_unique.columns[is_unique.all()].tolist()
|
1494
2639
|
|
1495
|
-
|
1496
|
-
|
2640
|
+
if len(unique_covariates) < len(covariate_obs):
|
2641
|
+
skipped = set(covariate_obs) - set(unique_covariates)
|
2642
|
+
print(f"[bold yellow]Covariates {skipped} have non-unique values! Skipping...")
|
2643
|
+
if unique_covariates:
|
2644
|
+
covariate_df_obs = adata.obs.groupby(sample_identifier, observed=True).first()[unique_covariates]
|
2645
|
+
covariate_df_ = covariate_df_.join(covariate_df_obs, how="left")
|
1497
2646
|
|
1498
2647
|
if covariate_df is not None:
|
1499
|
-
if
|
1500
|
-
raise ValueError("
|
1501
|
-
|
1502
|
-
covariate_df_ = pd.concat((covariate_df_, covs_ord), axis=1)
|
1503
|
-
|
1504
|
-
covariate_df_.index = covariate_df_.index.astype(str)
|
2648
|
+
if not covariate_df.index.equals(ct_count_data.index):
|
2649
|
+
raise ValueError("AnnData sample names and covariate_df index do not have the same elements!")
|
2650
|
+
covariate_df_ = covariate_df_.join(covariate_df, how="left")
|
1505
2651
|
|
1506
|
-
|
1507
|
-
var_dat = count_data.sum(axis=0).rename("n_cells").to_frame()
|
2652
|
+
var_dat = ct_count_data.sum(axis=0).rename("n_cells").to_frame()
|
1508
2653
|
var_dat.index = var_dat.index.astype(str)
|
1509
2654
|
|
1510
|
-
return AnnData(X=
|
2655
|
+
return AnnData(X=ct_count_data.values, var=var_dat, obs=covariate_df_)
|