pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +3 -2
- pertpy/data/__init__.py +5 -1
- pertpy/data/_dataloader.py +2 -4
- pertpy/data/_datasets.py +203 -92
- pertpy/metadata/__init__.py +4 -0
- pertpy/metadata/_cell_line.py +826 -0
- pertpy/metadata/_compound.py +129 -0
- pertpy/metadata/_drug.py +242 -0
- pertpy/metadata/_look_up.py +582 -0
- pertpy/metadata/_metadata.py +73 -0
- pertpy/metadata/_moa.py +129 -0
- pertpy/plot/__init__.py +1 -9
- pertpy/plot/_augur.py +53 -116
- pertpy/plot/_coda.py +277 -677
- pertpy/plot/_guide_rna.py +17 -35
- pertpy/plot/_milopy.py +59 -134
- pertpy/plot/_mixscape.py +152 -391
- pertpy/preprocessing/_guide_rna.py +88 -4
- pertpy/tools/__init__.py +8 -13
- pertpy/tools/_augur.py +315 -17
- pertpy/tools/_cinemaot.py +143 -4
- pertpy/tools/_coda/_base_coda.py +1210 -65
- pertpy/tools/_coda/_sccoda.py +50 -21
- pertpy/tools/_coda/_tasccoda.py +27 -19
- pertpy/tools/_dialogue.py +164 -56
- pertpy/tools/_differential_gene_expression.py +240 -14
- pertpy/tools/_distances/_distance_tests.py +8 -8
- pertpy/tools/_distances/_distances.py +184 -34
- pertpy/tools/_enrichment.py +465 -0
- pertpy/tools/_milo.py +345 -11
- pertpy/tools/_mixscape.py +668 -50
- pertpy/tools/_perturbation_space/_clustering.py +5 -1
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
- pertpy/tools/_perturbation_space/_simple.py +51 -10
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_scgen.py +701 -0
- pertpy/tools/_scgen/_utils.py +1 -3
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
- pertpy-0.7.0.dist-info/RECORD +53 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_coda/_base_coda.py
CHANGED
@@ -1,17 +1,23 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
from abc import ABC, abstractmethod
|
4
|
-
from
|
4
|
+
from pathlib import Path
|
5
|
+
from typing import TYPE_CHECKING, Literal, Optional, Union
|
5
6
|
|
6
7
|
import arviz as az
|
7
|
-
import ete3 as ete
|
8
8
|
import jax.numpy as jnp
|
9
|
+
import matplotlib.pyplot as plt
|
9
10
|
import numpy as np
|
10
11
|
import pandas as pd
|
11
12
|
import patsy as pt
|
13
|
+
import scanpy as sc
|
14
|
+
import seaborn as sns
|
15
|
+
from adjustText import adjust_text
|
12
16
|
from anndata import AnnData
|
13
|
-
from jax import random
|
14
|
-
from
|
17
|
+
from jax import config, random
|
18
|
+
from matplotlib import cm, rcParams
|
19
|
+
from matplotlib import image as mpimg
|
20
|
+
from matplotlib.colors import ListedColormap
|
15
21
|
from mudata import MuData
|
16
22
|
from numpyro.infer import HMC, MCMC, NUTS, initialization
|
17
23
|
from rich import box, print
|
@@ -20,10 +26,15 @@ from rich.table import Table
|
|
20
26
|
from scipy.cluster import hierarchy as sp_hierarchy
|
21
27
|
|
22
28
|
if TYPE_CHECKING:
|
29
|
+
from collections.abc import Sequence
|
30
|
+
|
23
31
|
import numpyro as npy
|
24
32
|
import toytree as tt
|
25
|
-
from
|
33
|
+
from ete3 import Tree
|
26
34
|
from jax._src.typing import Array
|
35
|
+
from matplotlib.axes import Axes
|
36
|
+
from matplotlib.colors import Colormap
|
37
|
+
from matplotlib.figure import Figure
|
27
38
|
|
28
39
|
config.update("jax_enable_x64", True)
|
29
40
|
|
@@ -179,7 +190,7 @@ class CompositionalModel2(ABC):
|
|
179
190
|
self,
|
180
191
|
sample_adata: AnnData,
|
181
192
|
kernel: npy.infer.mcmc.MCMCKernel,
|
182
|
-
rng_key: Array
|
193
|
+
rng_key: Array,
|
183
194
|
copy: bool = False,
|
184
195
|
*args,
|
185
196
|
**kwargs,
|
@@ -295,7 +306,7 @@ class CompositionalModel2(ABC):
|
|
295
306
|
if copy:
|
296
307
|
sample_adata = sample_adata.copy()
|
297
308
|
|
298
|
-
rng_key_array = random.
|
309
|
+
rng_key_array = random.key(rng_key)
|
299
310
|
sample_adata.uns["scCODA_params"]["mcmc"]["rng_key"] = np.array(rng_key_array)
|
300
311
|
|
301
312
|
# Set up NUTS kernel
|
@@ -335,7 +346,6 @@ class CompositionalModel2(ABC):
|
|
335
346
|
copy: Return a copy instead of writing to adata. Defaults to False.
|
336
347
|
|
337
348
|
Examples:
|
338
|
-
Example with scCODA:
|
339
349
|
>>> import pertpy as pt
|
340
350
|
>>> haber_cells = pt.dt.haber_2017_regions()
|
341
351
|
>>> sccoda = pt.tl.Sccoda()
|
@@ -358,10 +368,10 @@ class CompositionalModel2(ABC):
|
|
358
368
|
# Set rng key if needed
|
359
369
|
if rng_key is None:
|
360
370
|
rng = np.random.default_rng()
|
361
|
-
rng_key = random.
|
371
|
+
rng_key = random.key(rng.integers(0, 10000))
|
362
372
|
sample_adata.uns["scCODA_params"]["mcmc"]["rng_key"] = rng_key
|
363
373
|
else:
|
364
|
-
rng_key = random.
|
374
|
+
rng_key = random.key(rng_key)
|
365
375
|
|
366
376
|
# Set up HMC kernel
|
367
377
|
sample_adata = self.set_init_mcmc_states(
|
@@ -423,7 +433,6 @@ class CompositionalModel2(ABC):
|
|
423
433
|
- Is credible: Boolean indicator whether effect is credible
|
424
434
|
|
425
435
|
Examples:
|
426
|
-
Example with scCODA:
|
427
436
|
>>> import pertpy as pt
|
428
437
|
>>> haber_cells = pt.dt.haber_2017_regions()
|
429
438
|
>>> sccoda = pt.tl.Sccoda()
|
@@ -433,7 +442,6 @@ class CompositionalModel2(ABC):
|
|
433
442
|
>>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
|
434
443
|
>>> intercept_df, effect_df = sccoda.summary_prepare(mdata["coda"])
|
435
444
|
"""
|
436
|
-
# Get model and effect selection types
|
437
445
|
select_type = sample_adata.uns["scCODA_params"]["select_type"]
|
438
446
|
model_type = sample_adata.uns["scCODA_params"]["model_type"]
|
439
447
|
|
@@ -548,7 +556,11 @@ class CompositionalModel2(ABC):
|
|
548
556
|
intercept_df = intercept_df.loc[:, ["final_parameter", hdis[0], hdis[1], "sd", "expected_sample"]].copy()
|
549
557
|
intercept_df = intercept_df.rename(
|
550
558
|
columns=dict(
|
551
|
-
zip(
|
559
|
+
zip(
|
560
|
+
intercept_df.columns,
|
561
|
+
["Final Parameter", hdis_new[0], hdis_new[1], "SD", "Expected Sample"],
|
562
|
+
strict=False,
|
563
|
+
)
|
552
564
|
)
|
553
565
|
)
|
554
566
|
|
@@ -561,6 +573,7 @@ class CompositionalModel2(ABC):
|
|
561
573
|
zip(
|
562
574
|
effect_df.columns,
|
563
575
|
["Effect", "Median", hdis_new[0], hdis_new[1], "SD", "Expected Sample", "log2-fold change"],
|
576
|
+
strict=False,
|
564
577
|
)
|
565
578
|
)
|
566
579
|
)
|
@@ -581,6 +594,7 @@ class CompositionalModel2(ABC):
|
|
581
594
|
"Expected Sample",
|
582
595
|
"log2-fold change",
|
583
596
|
],
|
597
|
+
strict=False,
|
584
598
|
)
|
585
599
|
)
|
586
600
|
)
|
@@ -594,6 +608,7 @@ class CompositionalModel2(ABC):
|
|
594
608
|
zip(
|
595
609
|
node_df.columns,
|
596
610
|
["Final Parameter", "Median", hdis_new[0], hdis_new[1], "SD", "Delta", "Is credible"],
|
611
|
+
strict=False,
|
597
612
|
)
|
598
613
|
) # type: ignore
|
599
614
|
) # type: ignore
|
@@ -781,7 +796,6 @@ class CompositionalModel2(ABC):
|
|
781
796
|
kwargs: Passed to az.summary
|
782
797
|
|
783
798
|
Examples:
|
784
|
-
Example with scCODA:
|
785
799
|
>>> import pertpy as pt
|
786
800
|
>>> haber_cells = pt.dt.haber_2017_regions()
|
787
801
|
>>> sccoda = pt.tl.Sccoda()
|
@@ -799,7 +813,7 @@ class CompositionalModel2(ABC):
|
|
799
813
|
raise
|
800
814
|
if isinstance(data, AnnData):
|
801
815
|
sample_adata = data
|
802
|
-
|
816
|
+
|
803
817
|
select_type = sample_adata.uns["scCODA_params"]["select_type"]
|
804
818
|
model_type = sample_adata.uns["scCODA_params"]["model_type"]
|
805
819
|
|
@@ -926,7 +940,6 @@ class CompositionalModel2(ABC):
|
|
926
940
|
pd.DataFrame: Intercept data frame.
|
927
941
|
|
928
942
|
Examples:
|
929
|
-
Example with scCODA:
|
930
943
|
>>> import pertpy as pt
|
931
944
|
>>> haber_cells = pt.dt.haber_2017_regions()
|
932
945
|
>>> sccoda = pt.tl.Sccoda()
|
@@ -936,7 +949,6 @@ class CompositionalModel2(ABC):
|
|
936
949
|
>>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
|
937
950
|
>>> intercepts = sccoda.get_intercept_df(mdata)
|
938
951
|
"""
|
939
|
-
|
940
952
|
if isinstance(data, MuData):
|
941
953
|
try:
|
942
954
|
sample_adata = data[modality_key]
|
@@ -959,7 +971,6 @@ class CompositionalModel2(ABC):
|
|
959
971
|
pd.DataFrame: Effect data frame.
|
960
972
|
|
961
973
|
Examples:
|
962
|
-
Example with scCODA:
|
963
974
|
>>> import pertpy as pt
|
964
975
|
>>> haber_cells = pt.dt.haber_2017_regions()
|
965
976
|
>>> sccoda = pt.tl.Sccoda()
|
@@ -969,7 +980,6 @@ class CompositionalModel2(ABC):
|
|
969
980
|
>>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
|
970
981
|
>>> effects = sccoda.get_effect_df(mdata)
|
971
982
|
"""
|
972
|
-
|
973
983
|
if isinstance(data, MuData):
|
974
984
|
try:
|
975
985
|
sample_adata = data[modality_key]
|
@@ -1003,9 +1013,8 @@ class CompositionalModel2(ABC):
|
|
1003
1013
|
pd.DataFrame: Node effect data frame.
|
1004
1014
|
|
1005
1015
|
Examples:
|
1006
|
-
Example with tascCODA (works only for model of type tree_agg, i.e. a tascCODA model):
|
1007
1016
|
>>> import pertpy as pt
|
1008
|
-
>>> adata = pt.dt.
|
1017
|
+
>>> adata = pt.dt.tasccoda_example()
|
1009
1018
|
>>> tasccoda = pt.tl.Tasccoda()
|
1010
1019
|
>>> mdata = tasccoda.load(
|
1011
1020
|
>>> adata, type="sample_level",
|
@@ -1113,6 +1122,1136 @@ class CompositionalModel2(ABC):
|
|
1113
1122
|
|
1114
1123
|
return out
|
1115
1124
|
|
1125
|
+
def _stackbar(  # pragma: no cover
    self,
    y: np.ndarray,
    type_names: list[str],
    title: str,
    level_names: list[str],
    figsize: tuple[float, float] | None = None,
    dpi: int | None = 100,
    palette: ListedColormap | None = cm.tab20,
    show_legend: bool | None = True,
) -> plt.Axes:
    """Draw one stacked barplot of per-row cell type proportions (in percent).

    Every row of ``y`` becomes one bar. Within a bar, each column (cell type) is drawn as a segment
    whose height is that column's share of the row total, multiplied by 100.
    Typical use (only inside stacked_barplot): plot_one_stackbar(data.X, data.var.index, "xyz", data.obs.index)

    Args:
        y: The count data, collapsed onto the level of interest, i.e. a binary covariate has two rows,
            one for each group, containing the count mean of each cell type.
        type_names: The names of all cell types, one per column of ``y``; used as legend labels.
        title: Plot title, usually the covariate's name.
        level_names: Names of the covariate's levels, used as x tick labels (one per row of ``y``).
        figsize: Figure size. Falls back to ``rcParams["figure.figsize"]`` when None.
        dpi: Dpi setting. Defaults to 100.
        palette: The color map for the barplot; colors cycle modulo ``palette.N``. Defaults to cm.tab20.
        show_legend: If True, adds a legend to the right of the axes. Defaults to True.

    Returns:
        A :class:`~matplotlib.axes.Axes` object
    """
    n_bars, n_types = y.shape
    if figsize is None:
        figsize = rcParams["figure.figsize"]

    _, ax = plt.subplots(figsize=figsize, dpi=dpi)

    positions = np.array(range(n_bars))
    row_totals = np.sum(y, axis=1)
    bottoms = np.zeros(n_bars)

    for col in range(n_types):
        # Share (in percent) that this cell type contributes to each bar's total.
        shares = y[:, col] / row_totals * 100
        ax.bar(
            positions,
            shares,
            bottom=bottoms,
            color=palette(col % palette.N),
            width=0.85,
            label=type_names[col],
            linewidth=0,
        )
        # Next segment starts where this one ends.
        bottoms += shares

    ax.set_title(title)
    if show_legend:
        ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1)
    ax.set_xticks(positions)
    ax.set_xticklabels(level_names, rotation=45, ha="right")
    ax.set_ylabel("Proportion")

    return ax
|
1186
|
+
|
1187
|
+
def plot_stacked_barplot(  # pragma: no cover
    self,
    data: AnnData | MuData,
    feature_name: str,
    modality_key: str = "coda",
    palette: ListedColormap | None = cm.tab20,
    show_legend: bool | None = True,
    level_order: list[str] = None,
    figsize: tuple[float, float] | None = None,
    dpi: int | None = 100,
    return_fig: bool | None = None,
    ax: plt.Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
    **kwargs,
) -> plt.Axes | plt.Figure | None:
    """Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples").

    For a covariate, the counts of all samples sharing a level are summed before proportions are computed.

    Args:
        data: AnnData object or MuData object.
        feature_name: The name of the covariate to plot. If feature_name=="samples", one bar for every sample will be plotted
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        figsize: Figure size. Defaults to None.
        dpi: Dpi setting. Defaults to 100.
        palette: The matplotlib color map for the barplot. Defaults to cm.tab20.
        show_legend: If True, adds a legend. Defaults to True.
        level_order: Custom ordering of bars on the x-axis. Defaults to None.
        return_fig: If True, return the current matplotlib figure.
        ax: Currently unused; a new figure is always created.
        show: If True, call ``plt.show()``.
        save: If truthy, passed to ``plt.savefig``.
        **kwargs: Currently unused.

    Returns:
        A :class:`~matplotlib.axes.Axes` object

    Examples:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
                sample_identifier="batch", covariate_obs=["condition"])
        >>> sccoda.plot_stacked_barplot(mdata, feature_name="samples")

    Preview:
        .. image:: /_static/docstring_previews/sccoda_stacked_barplot.png
    """
    if isinstance(data, MuData):
        data = data[modality_key]
    # An AnnData input is used unchanged.

    style_kwargs = {"figsize": figsize, "dpi": dpi, "palette": palette, "show_legend": show_legend}

    if feature_name == "samples":
        # One bar per sample, optionally reordered by sample name.
        if level_order:
            assert set(level_order) == set(data.obs.index), "level order is inconsistent with levels"
            data = data[level_order]
        ax = self._stackbar(
            data.X,
            type_names=data.var.index,
            title="samples",
            level_names=data.obs.index,
            **style_kwargs,
        )
    else:
        covariate_values = data.obs[feature_name]

        # Decide the bar order: explicit order, categorical order, or order of first appearance.
        if level_order:
            assert set(level_order) == set(covariate_values), "level order is inconsistent with levels"
            levels = level_order
        elif hasattr(covariate_values, "cat"):
            levels = covariate_values.cat.categories.to_list()
        else:
            levels = pd.unique(covariate_values)

        # Sum the counts of all samples belonging to each level.
        level_totals = np.zeros([len(levels), data.X.shape[1]])
        for row, level_value in enumerate(levels):
            level_totals[row] = np.sum(data.X[np.where(covariate_values == level_value)], axis=0)

        ax = self._stackbar(
            level_totals,
            type_names=data.var.index,
            title=feature_name,
            level_names=levels,
            **style_kwargs,
        )

    if save:
        plt.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return plt.gcf()
    return None if (show or save) else ax
|
1287
|
+
|
1288
|
+
def plot_effects_barplot(  # pragma: no cover
    self,
    data: AnnData | MuData,
    modality_key: str = "coda",
    covariates: str | list | None = None,
    parameter: Literal["log2-fold change", "Final Parameter", "Expected Sample"] = "log2-fold change",
    plot_facets: bool = True,
    plot_zero_covariate: bool = True,
    plot_zero_cell_type: bool = False,
    palette: str | ListedColormap | None = cm.tab20,
    level_order: list[str] | None = None,
    args_barplot: dict | None = None,
    figsize: tuple[float, float] | None = None,
    dpi: int | None = 100,
    return_fig: bool | None = None,
    ax: plt.Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None:
    """Barplot visualization for effects.

    The effect results for each covariate are shown as a group of barplots, with intra-group separation by cell types.
    The covariate groups can either be ordered along the x-axis of a single plot (plot_facets=False) or as plot facets (plot_facets=True).

    Args:
        data: AnnData object or MuData object.
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        covariates: The name of the covariates in data.obs to plot. A covariate is selected when any of the given
            strings is a substring of its name. Defaults to None (all covariates).
        parameter: The parameter in effect summary to plot. Defaults to "log2-fold change".
        plot_facets: If False, plot cell types on the x-axis. If True, plot as facets.
            Defaults to True.
        plot_zero_covariate: If True, plot covariates that have all zero effects. If False, do not plot.
            Defaults to True.
        plot_zero_cell_type: If True, plot cell types that have zero effect. If False, do not plot.
            Defaults to False.
        figsize: Figure size. Defaults to None.
        dpi: Dpi setting. Defaults to 100.
        palette: The seaborn color map for the barplot. Defaults to cm.tab20.
        level_order: Custom ordering of bars on the x-axis. Defaults to None.
        args_barplot: Arguments passed to sns.barplot. Defaults to None.
        return_fig: If True, return the matplotlib figure.
        ax: Axes to draw on when plot_facets is False. Ignored for facet plots, which create their own figure.
        show: If True, call ``plt.show()``.
        save: If truthy, passed to ``savefig``.

    Returns:
        Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
        or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object

    Raises:
        ValueError: If no covariate is left to plot after applying `covariates` and `plot_zero_covariate`.

    Examples:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
                sample_identifier="batch", covariate_obs=["condition"])
        >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
        >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
        >>> sccoda.plot_effects_barplot(mdata)

    Preview:
        .. image:: /_static/docstring_previews/sccoda_effects_barplot.png
    """
    if args_barplot is None:
        args_barplot = {}
    if isinstance(data, MuData):
        data = data[modality_key]

    # Get covariate names from adata, partition into those with nonzero effects for min. one cell type/no cell types
    covariate_names = data.uns["scCODA_params"]["covariate_names"]
    if covariates is not None:
        if isinstance(covariates, str):
            covariates = [covariates]
        covariate_names = [
            covariate_name
            for covariate_name in covariate_names
            if any(covariate in covariate_name for covariate in covariates)
        ]
    covariate_names_non_zero = [
        covariate_name
        for covariate_name in covariate_names
        if data.varm[f"effect_df_{covariate_name}"][parameter].any()
    ]
    covariate_names_zero = list(set(covariate_names) - set(covariate_names_non_zero))
    if not plot_zero_covariate:
        covariate_names = covariate_names_non_zero
    if len(covariate_names) == 0:
        raise ValueError("No covariates left to plot. Check `covariates` and `plot_zero_covariate`.")

    # Long format: one row per (cell type, covariate) with the effect in column "value".
    # The code below relies on the index of each effect_df being named "Cell Type".
    plot_df = pd.concat(
        [data.varm[f"effect_df_{covariate_name}"][parameter] for covariate_name in covariate_names],
        axis=1,
    )
    plot_df.columns = covariate_names
    plot_df = pd.melt(plot_df, ignore_index=False, var_name="Covariate").reset_index()

    # Placeholder "cell type" that keeps all-zero covariates visible as (empty) facets.
    placeholder = "zero"
    if len(covariate_names_zero) != 0 and plot_facets and plot_zero_covariate and not plot_zero_cell_type:
        plot_df = plot_df[plot_df["value"] != 0]
        placeholder_rows = pd.DataFrame({"Covariate": covariate_names_zero, "Cell Type": placeholder, "value": 0})
        plot_df = pd.concat([plot_df, placeholder_rows], ignore_index=True)
        plot_df["covariate_"] = pd.Categorical(plot_df["Covariate"], covariate_names)
        plot_df = plot_df.sort_values(["covariate_"])

    if not plot_zero_cell_type:
        # Drop cell types whose effect is zero for every covariate. The placeholder is exempt: removing it
        # would drop the facets of all-zero covariates and misalign the facet titles set below.
        cell_type_names_zero = [
            name
            for name in plot_df["Cell Type"].unique()
            if name != placeholder and (plot_df.loc[plot_df["Cell Type"] == name, "value"] == 0).all()
        ]
        plot_df = plot_df[~plot_df["Cell Type"].isin(cell_type_names_zero)]

    # If plot as facets, create a FacetGrid and map barplot to it.
    if plot_facets:
        if isinstance(palette, ListedColormap):
            palette = np.array([palette(i % palette.N) for i in range(len(plot_df["Cell Type"].unique()))]).tolist()
        if figsize is not None:
            # NOTE(review): figsize[0] is used as facet height and figsize[1] / figsize[0] as aspect ratio,
            # which differs from matplotlib's (width, height) convention; kept for backward compatibility.
            height = figsize[0]
            aspect = np.round(figsize[1] / figsize[0], 2)
        else:
            height = 3
            aspect = 2

        g = sns.FacetGrid(
            plot_df,
            col="Covariate",
            # Pin the facet order so the titles assigned from covariate_names below line up.
            col_order=list(covariate_names),
            sharey=True,
            sharex=False,
            height=height,
            aspect=aspect,
        )

        g.map(
            sns.barplot,
            "Cell Type",
            "value",
            palette=palette,
            order=level_order,
            **args_barplot,
        )
        g.set_xticklabels(rotation=90)
        g.set(ylabel=parameter)
        for covariate_name, facet_ax in zip(covariate_names, g.axes.flatten(), strict=False):
            facet_ax.set_title(covariate_name)
            tick_labels = facet_ax.get_xticklabels()
            if 0 < len(tick_labels) < 5:
                # Presumably keeps facets with few bars from getting overly wide bars.
                facet_ax.set_aspect(10 / len(tick_labels))
            if len(tick_labels) == 1 and tick_labels[0].get_text() == placeholder:
                # The placeholder only exists to create the facet; hide its tick.
                facet_ax.set_xticks([])

        if save:
            plt.savefig(save, bbox_inches="tight")
        if show:
            plt.show()
        if return_fig:
            return plt.gcf()
        if not (show or save):
            return g
        return None

    # If not plot as facets, call barplot to plot cell types on the x-axis.
    if ax is None:
        _, ax = plt.subplots(figsize=figsize, dpi=dpi)
    if len(covariate_names) == 1:
        if isinstance(palette, ListedColormap):
            palette = np.array([palette(i % palette.N) for i in range(len(plot_df["Cell Type"].unique()))]).tolist()
        # Color each bar by its cell type: hue must name an existing column, so it mirrors x.
        # dodge=False keeps full-width bars since each x position has exactly one hue level.
        barplot_kwargs = {"dodge": False, **args_barplot}
        sns.barplot(
            data=plot_df,
            x="Cell Type",
            y="value",
            hue="Cell Type",
            palette=palette,
            order=level_order,
            ax=ax,
            **barplot_kwargs,
        )
        # A legend of cell types would only repeat the x tick labels.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
        ax.set_title(covariate_names[0])
    else:
        if isinstance(palette, ListedColormap):
            palette = np.array([palette(i % palette.N) for i in range(len(covariate_names))]).tolist()
        sns.barplot(
            data=plot_df,
            x="Cell Type",
            y="value",
            hue="Covariate",
            palette=palette,
            order=level_order,
            ax=ax,
            **args_barplot,
        )
    # Rotate the labels seaborn assigned rather than overwriting them, which could mismatch the bar order.
    ax.tick_params(axis="x", labelrotation=90)

    if save:
        ax.figure.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return ax.figure
    if not (show or save):
        return ax
    return None
|
1492
|
+
|
1493
|
+
def plot_boxplots(  # pragma: no cover
    self,
    data: AnnData | MuData,
    feature_name: str,
    modality_key: str = "coda",
    y_scale: Literal["relative", "log", "log10", "count"] = "relative",
    plot_facets: bool = False,
    add_dots: bool = False,
    cell_types: list | None = None,
    args_boxplot: dict | None = None,
    args_swarmplot: dict | None = None,
    palette: str | None = "Blues",
    show_legend: bool | None = True,
    level_order: list[str] | None = None,
    figsize: tuple[float, float] | None = None,
    dpi: int | None = 100,
    return_fig: bool | None = None,
    ax: plt.Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None:
    """Grouped boxplot visualization.

    The cell counts for each cell type are shown as a group of boxplots
    with intra-group separation by a covariate from data.obs.

    Args:
        data: AnnData object or MuData object
        feature_name: The name of the feature in data.obs to plot
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        y_scale: Transformation of cell counts. Options: "relative" - Relative abundance, "log" - log(count),
            "log10" - log10(count), "count" - absolute abundance (cell counts).
            For the log scales, zero counts are replaced by a pseudocount of 0.5.
            Defaults to "relative".
        plot_facets: If False, plot cell types on the x-axis. If True, plot as facets. Defaults to False.
        add_dots: If True, overlay a scatterplot with one dot for each data point. Defaults to False.
        cell_types: Subset of cell types that should be plotted. Defaults to None.
        args_boxplot: Arguments passed to sns.boxplot. Defaults to {}. Not modified.
        args_swarmplot: Arguments passed to sns.swarmplot. Defaults to {}. Not modified.
        figsize: Figure size. Defaults to None.
        dpi: Dpi setting. Defaults to 100.
        palette: The seaborn color map for the barplot. Defaults to "Blues".
        show_legend: If True, adds a legend (only when plot_facets is False). Defaults to True.
        level_order: Custom ordering of bars on the x-axis. Defaults to None.
        return_fig: If True, return the matplotlib figure.
        ax: Axes to draw on when plot_facets is False. Ignored for facet plots, which create their own figure.
        show: If True, call ``plt.show()``.
        save: If truthy, passed to ``savefig``.

    Returns:
        Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
        or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object

    Raises:
        ValueError: If `y_scale` is not one of the supported options.

    Examples:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
                sample_identifier="batch", covariate_obs=["condition"])
        >>> sccoda.plot_boxplots(mdata, feature_name="condition", add_dots=True)

    Preview:
        .. image:: /_static/docstring_previews/sccoda_boxplots.png
    """
    # Work on copies: keys are popped/added below and must not leak into the caller's dictionaries.
    args_boxplot = {} if args_boxplot is None else dict(args_boxplot)
    args_swarmplot = {} if args_swarmplot is None else dict(args_swarmplot)
    if isinstance(data, MuData):
        data = data[modality_key]

    # y scale transformations
    if y_scale == "relative":
        sample_sums = np.sum(data.X, axis=1, keepdims=True)
        X = data.X / sample_sums
        value_name = "Proportion"
    elif y_scale in ("log", "log10"):
        # Pseudocount 0.5 for zero counts. Convert to float first: for an integer count matrix the
        # assignment would otherwise truncate 0.5 back to 0 and the log would yield -inf.
        X = np.array(data.X, dtype=float)
        X[X == 0] = 0.5
        if y_scale == "log":
            X = np.log(X)
            value_name = "log(count)"
        else:
            X = np.log10(X)
            value_name = "log10(count)"
    elif y_scale == "count":
        X = data.X
        value_name = "count"
    else:
        raise ValueError("Invalid y_scale transformation")

    count_df = pd.DataFrame(X, columns=data.var.index, index=data.obs.index).merge(
        data.obs[feature_name], left_index=True, right_index=True
    )
    plot_df = pd.melt(count_df, id_vars=feature_name, var_name="Cell type", value_name=value_name)
    if cell_types is not None:
        plot_df = plot_df[plot_df["Cell type"].isin(cell_types)]

    # Currently disabled because the latest statsannotations does not support the latest seaborn.
    # We had to drop the dependency.
    # Get credible effects results from model
    # if draw_effects:
    #     if model is not None:
    #         credible_effects_df = model.credible_effects(data, modality_key).to_frame().reset_index()
    #     else:
    #         print("[bold yellow]Specify a tasCODA model to draw effects")
    #     credible_effects_df[feature_name] = credible_effects_df["Covariate"].str.removeprefix(f"{feature_name}[T.")
    #     credible_effects_df[feature_name] = credible_effects_df[feature_name].str.removesuffix("]")
    #     credible_effects_df = credible_effects_df[credible_effects_df["Final Parameter"]]

    # If plot as facets, create a FacetGrid and map boxplot to it.
    if plot_facets:
        if level_order is None:
            level_order = pd.unique(plot_df[feature_name])

        K = X.shape[1]

        if figsize is not None:
            # NOTE(review): figsize[0] is used as facet height and figsize[1] / figsize[0] as aspect ratio,
            # which differs from matplotlib's (width, height) convention; kept for backward compatibility.
            height = figsize[0]
            aspect = np.round(figsize[1] / figsize[0], 2)
        else:
            height = 3
            aspect = 2

        g = sns.FacetGrid(
            plot_df,
            col="Cell type",
            sharey=False,
            col_wrap=int(np.floor(np.sqrt(K))),
            height=height,
            aspect=aspect,
        )
        g.map(
            sns.boxplot,
            feature_name,
            value_name,
            palette=palette,
            order=level_order,
            **args_boxplot,
        )

        if add_dots:
            hue = args_swarmplot.pop("hue", None)
            if hue is None:
                g.map(
                    sns.swarmplot,
                    feature_name,
                    value_name,
                    color="black",
                    order=level_order,
                    **args_swarmplot,
                ).set_titles("{col_name}")
            else:
                g.map(
                    sns.swarmplot,
                    feature_name,
                    value_name,
                    hue,
                    order=level_order,
                    **args_swarmplot,
                ).set_titles("{col_name}")

        if save:
            plt.savefig(save, bbox_inches="tight")
        if show:
            plt.show()
        if return_fig:
            return plt.gcf()
        if not (show or save):
            return g
        return None

    # If not plot as facets, call boxplot to plot cell types on the x-axis.
    if level_order:
        args_boxplot["hue_order"] = level_order
        args_swarmplot["hue_order"] = level_order

    if ax is None:
        _, ax = plt.subplots(figsize=figsize, dpi=dpi)

    ax = sns.boxplot(
        x="Cell type",
        y=value_name,
        hue=feature_name,
        data=plot_df,
        fliersize=1,
        palette=palette,
        ax=ax,
        **args_boxplot,
    )

    # Currently disabled because the latest statsannotations does not support the latest seaborn.
    # We had to drop the dependency.
    # if draw_effects:
    #     pairs = [
    #         [(row["Cell Type"], row[feature_name]), (row["Cell Type"], "Control")]
    #         for _, row in credible_effects_df.iterrows()
    #     ]
    #     annot = Annotator(ax, pairs, data=plot_df, x="Cell type", y=value_name, hue=feature_name)
    #     annot.configure(test=None, loc="outside", color="red", line_height=0, verbose=False)
    #     annot.set_custom_annotations([row[feature_name] for _, row in credible_effects_df.iterrows()])
    #     annot.annotate()

    if add_dots:
        sns.swarmplot(
            x="Cell type",
            y=value_name,
            data=plot_df,
            hue=feature_name,
            ax=ax,
            dodge=True,
            palette="dark:black",
            **args_swarmplot,
        )

    # Rotate the labels seaborn assigned rather than overwriting them, which could mismatch the box order.
    ax.tick_params(axis="x", labelrotation=90)

    if show_legend:
        # Boxplot and swarmplot both add entries per level; keep the first handle for each label.
        handles, labels = ax.get_legend_handles_labels()
        handout = []
        labelout = []
        for handle, label in zip(handles, labels, strict=False):
            if label not in labelout:
                labelout.append(label)
                handout.append(handle)
        ax.legend(
            handout,
            labelout,
            loc="upper left",
            bbox_to_anchor=(1, 1),
            ncol=1,
            title=feature_name,
        )

    if save:
        ax.figure.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return ax.figure
    if not (show or save):
        return ax
    return None
|
1739
|
+
|
1740
|
+
def plot_rel_abundance_dispersion_plot(  # pragma: no cover
    self,
    data: AnnData | MuData,
    modality_key: str = "coda",
    abundant_threshold: float | None = 0.9,
    default_color: str | None = "Grey",
    abundant_color: str | None = "Red",
    label_cell_types: bool = True,
    figsize: tuple[float, float] | None = None,
    dpi: int | None = 100,
    return_fig: bool | None = None,
    ax: plt.Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> plt.Axes | plt.Figure | None:
    """Plots the dispersion of relative abundance versus presence of all cell types to help choose a reference cell type.

    For every cell type, the dispersion (variance / mean) of its relative abundance across samples is plotted
    against its presence (the fraction of samples with a non-zero count).
    Cell types that are present in more than `abundant_threshold` of all samples are marked in a different color.

    Args:
        data: AnnData or MuData object.
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        abundant_threshold: Presence threshold (fraction of samples) for abundant cell types. Defaults to 0.9.
        default_color: Dot color for all non-abundant cell types. Defaults to "Grey".
        abundant_color: Dot color for cell types whose presence is larger than abundant_threshold.
                        Defaults to "Red".
        label_cell_types: Label the dots of abundant cell types with their names. Defaults to True.
        figsize: Figure size. Only used if `ax` is None. Defaults to None.
        dpi: Dpi setting. Only used if `ax` is None. Defaults to 100.
        return_fig: If True, return the current matplotlib figure. Defaults to None.
        ax: A matplotlib axes object to draw into. Defaults to None.
        show: If True, call `plt.show()`. Defaults to None.
        save: Path passed to `plt.savefig`. Defaults to None.

    Returns:
        The current :class:`~matplotlib.figure.Figure` if `return_fig` is True,
        otherwise the :class:`~matplotlib.axes.Axes` if neither `show` nor `save` is set, else None.

    Examples:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
            sample_identifier="batch", covariate_obs=["condition"])
        >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
        >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
        >>> sccoda.plot_rel_abundance_dispersion_plot(mdata)

    Preview:
        .. image:: /_static/docstring_previews/sccoda_rel_abundance_dispersion_plot.png
    """
    if isinstance(data, MuData):
        data = data[modality_key]
    if ax is None:
        _, ax = plt.subplots(figsize=figsize, dpi=dpi)

    counts = data.X
    # Relative abundance of every cell type within each sample (each row sums to 1).
    rel_abun = counts / np.sum(counts, axis=1, keepdims=True)

    # Fraction of samples in which each cell type has a zero count.
    percent_zero = np.sum(counts == 0, axis=0) / counts.shape[0]
    # Vectorised membership test instead of scanning an index array once per cell type.
    is_abundant = percent_zero < 1 - abundant_threshold

    # Dispersion index (variance / mean) of each cell type's relative abundance.
    cell_type_disp = np.var(rel_abun, axis=0) / np.mean(rel_abun, axis=0)

    plot_df = pd.DataFrame(
        {
            "Total dispersion": cell_type_disp,
            "Cell type": data.var.index,
            "Presence": 1 - percent_zero,
            "Is abundant": is_abundant,
        }
    )

    # Seaborn orders boolean hue levels as (False, True); the palette must list only the levels present.
    if plot_df["Is abundant"].all():
        palette = [abundant_color]
    elif not plot_df["Is abundant"].any():
        palette = [default_color]
    else:
        palette = [default_color, abundant_color]

    ax = sns.scatterplot(
        data=plot_df,
        x="Presence",
        y="Total dispersion",
        hue="Is abundant",
        palette=palette,
        ax=ax,
    )

    if label_cell_types:
        abundant_df = plot_df.loc[plot_df["Is abundant"], :]
        # Draw the labels on the axes we plotted into (not plt.gca()), so a user-supplied `ax` is honoured.
        texts = [
            ax.text(x, y, str(name))
            for x, y, name in zip(
                abundant_df["Presence"], abundant_df["Total dispersion"], abundant_df["Cell type"], strict=True
            )
        ]
        if texts:
            adjust_text(texts, ax=ax)

    ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1, title="Is abundant")

    if save:
        plt.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return plt.gcf()
    if not (show or save):
        return ax
    return None
|
1866
|
+
|
1867
|
+
def plot_draw_tree(  # pragma: no cover
    self,
    data: AnnData | MuData,
    modality_key: str = "coda",
    tree: str = "tree",  # Also type ete3.Tree. Omitted due to import errors
    tight_text: bool | None = False,
    show_scale: bool | None = False,
    units: Literal["px", "mm", "in"] | None = "px",
    figsize: tuple[float, float] | None = (None, None),
    dpi: int | None = 100,
    show: bool | None = True,
    save: str | bool | None = None,
) -> Tree | None:
    """Plot a tree using input ete3 tree object.

    Args:
        data: AnnData object or MuData object.
        modality_key: If data is a MuData object, specify which modality to use.
                      Defaults to "coda".
        tree: A ete3 tree object or a str to indicate the tree stored in `.uns`.
              Defaults to "tree".
        tight_text: When False, boundaries of the text are approximated according to general font metrics,
                    producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
                    Defaults to False.
        show_scale: Include the scale legend in the tree image or not.
                    Defaults to False.
        units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches. Defaults to "px".
        figsize: Image size as (width, height) in `units`. None (or None entries) leaves the size to ete3.
                 Defaults to (None, None).
        dpi: Dots per inches. Defaults to 100.
        show: If True, plot the tree inline. If False, return tree and tree_style objects.
              Defaults to True.
        save: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG.
              Output image can be saved whether show is True or not.
              Defaults to None.

    Returns:
        If `show` is True, the inline rendering of the tree; otherwise a tuple of the
        ete3 tree (:class:`ete3.TreeNode`) and its :class:`ete3.TreeStyle`.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.tasccoda_example()
        >>> tasccoda = pt.tl.Tasccoda()
        >>> mdata = tasccoda.load(
        >>>     adata, type="sample_level",
        >>>     levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
        >>>     key_added="lineage", add_level_name=True
        >>> )
        >>> mdata = tasccoda.prepare(
        >>>     mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
        >>> )
        >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
        >>> tasccoda.plot_draw_tree(mdata, tree="lineage")

    Preview:
        .. image:: /_static/docstring_previews/tasccoda_draw_tree.png
    """
    try:
        from ete3 import TextFace, TreeStyle, faces
    except ImportError:
        raise ImportError(
            "To use tasccoda please install additional dependencies with `pip install pertpy[coda]`"
        ) from None

    if isinstance(data, MuData):
        data = data[modality_key]
    if isinstance(tree, str):
        tree = data.uns[tree]

    # `figsize=None` is allowed by the signature; fall back to ete3 render()'s own w/h defaults (None).
    width, height = figsize if figsize is not None else (None, None)

    def my_layout(node):
        # Print every node's name to the right of its branch.
        text_face = TextFace(node.name, tight_text=tight_text)
        faces.add_face_to_node(text_face, node, column=0, position="branch-right")

    tree_style = TreeStyle()
    tree_style.show_leaf_name = False  # names are drawn by my_layout instead
    tree_style.layout_fn = my_layout
    tree_style.show_scale = show_scale

    if save:
        tree.render(save, tree_style=tree_style, units=units, w=width, h=height, dpi=dpi)  # type: ignore
    if show:
        return tree.render("%%inline", tree_style=tree_style, units=units, w=width, h=height, dpi=dpi)  # type: ignore
    return tree, tree_style
|
1952
|
+
|
1953
|
+
def plot_draw_effects(  # pragma: no cover
    self,
    data: AnnData | MuData,
    covariate: str,
    modality_key: str = "coda",
    tree: str = "tree",  # Also type ete3.Tree. Omitted due to import errors
    show_legend: bool | None = None,
    show_leaf_effects: bool | None = False,
    tight_text: bool | None = False,
    show_scale: bool | None = False,
    units: Literal["px", "mm", "in"] | None = "px",
    figsize: tuple[float, float] | None = (None, None),
    dpi: int | None = 100,
    show: bool | None = True,
    save: str | None = None,
) -> Tree | None:
    """Plot a tree with colored circles on the nodes indicating significant effects with bar plots which indicate leave-level significant effects.

    Node circles are blue for positive and red for negative effects; their radius is proportional to the
    absolute effect, with the largest absolute node effect drawn at radius 10.

    Args:
        data: AnnData object or MuData object.
        covariate: The covariate, whose effects should be plotted.
        modality_key: If data is a MuData object, specify which modality to use.
                      Defaults to "coda".
        tree: A ete3 tree object or a str to indicate the tree stored in `.uns`.
              Defaults to "tree".
        show_legend: If show legend of nodes significant effects or not.
                     Defaults to False if show_leaf_effects is True, else True.
        show_leaf_effects: If True, plot bar plots which indicate leave-level significant effects
                           next to the tree (a matplotlib figure is created in that case).
                           Defaults to False.
        tight_text: When False, boundaries of the text are approximated according to general font metrics,
                    producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
                    Defaults to False.
        show_scale: Include the scale legend in the tree image or not. Defaults to False.
        units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches. Defaults to "px".
        figsize: Tree image size as (width, height) in `units`. Not used when show_leaf_effects is True.
                 Defaults to (None, None).
        dpi: Dots per inches. Not used when show_leaf_effects is True. Defaults to 100.
        show: If True, plot the tree inline. If False, return tree and tree_style objects. Defaults to True.
        save: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG.
              Output image can be saved whether show is True or not. Defaults to None.

    Returns:
        None if `show_leaf_effects` is True. Otherwise, the inline rendering of the tree if `show` is True,
        or a tuple of the collapsed ete3 tree (:class:`ete3.TreeNode`) and its :class:`ete3.TreeStyle`.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.tasccoda_example()
        >>> tasccoda = pt.tl.Tasccoda()
        >>> mdata = tasccoda.load(
        >>>     adata, type="sample_level",
        >>>     levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
        >>>     key_added="lineage", add_level_name=True
        >>> )
        >>> mdata = tasccoda.prepare(
        >>>     mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
        >>> )
        >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
        >>> tasccoda.plot_draw_effects(mdata, covariate="Health[T.Inflamed]", tree="lineage")

    Preview:
        .. image:: /_static/docstring_previews/tasccoda_draw_effects.png
    """
    import tempfile

    try:
        from ete3 import CircleFace, NodeStyle, TextFace, TreeStyle, faces
    except ImportError:
        raise ImportError(
            "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
        ) from None

    if isinstance(data, MuData):
        data = data[modality_key]
    if show_legend is None:
        show_legend = not show_leaf_effects
    elif show_legend and show_leaf_effects:
        # The misalignment only exists when the leaf-effect bar plot is drawn next to the tree.
        print("Tree leaves and leaf effect bars won't be aligned when legend is shown!")

    # `figsize=None` is allowed by the signature; fall back to ete3 render()'s own w/h defaults (None).
    width, height = figsize if figsize is not None else (None, None)

    if isinstance(tree, str):
        tree = data.uns[tree]
    # Collapse tree singularities
    tree2 = collapse_singularities_2(tree)

    # Node-level effects of the requested covariate, indexed by node name.
    node_effs = data.uns["scCODA_params"]["node_df"].loc[(covariate + "_node",),].copy()
    node_effs.index = node_effs.index.get_level_values("Node")

    # Leaf-level (cell type) effects of the requested covariate, indexed by cell type.
    covariates = data.uns["scCODA_params"]["covariate_names"]
    effect_dfs = [data.varm[f"effect_df_{cov}"] for cov in covariates]
    eff_df = pd.concat(effect_dfs)
    eff_df.index = pd.MultiIndex.from_product(
        (covariates, data.var.index.tolist()),
        names=["Covariate", "Cell Type"],
    )
    leaf_effs = eff_df.loc[(covariate,),].copy()
    leaf_effs.index = leaf_effs.index.get_level_values("Cell Type")

    # Hide the default node markers and attach both effect values (0 when absent) to every node.
    for n in tree2.traverse():
        nstyle = NodeStyle()
        nstyle["size"] = 0
        n.set_style(nstyle)
        n.add_feature("node_effect", node_effs.loc[n.name, "Final Parameter"] if n.name in node_effs.index else 0)
        n.add_feature("leaf_effect", leaf_effs.loc[n.name, "Effect"] if n.name in leaf_effs.index else 0)

    # Scale effect values to get nice node sizes
    eff_max = np.max([np.abs(n.node_effect) for n in tree2.traverse()])
    leaf_eff_max = np.max([np.abs(n.leaf_effect) for n in tree2.traverse()])

    def my_layout(node):
        text_face = TextFace(node.name, tight_text=tight_text)
        text_face.margin_left = 10
        faces.add_face_to_node(text_face, node, column=0, aligned=True)

        # Radius proportional to |node effect|; the largest absolute effect gets radius 10.
        size = (np.abs(node.node_effect) * 10 / eff_max) if node.node_effect != 0 else 0
        if np.sign(node.node_effect) == 1:
            color = "blue"
        elif np.sign(node.node_effect) == -1:
            color = "red"
        else:
            color = "cyan"
        if size != 0:
            faces.add_face_to_node(CircleFace(radius=size, color=color), node, column=0)

    tree_style = TreeStyle()
    tree_style.show_leaf_name = False
    tree_style.layout_fn = my_layout
    tree_style.show_scale = show_scale
    tree_style.draw_guiding_lines = True
    tree_style.legend_position = 1

    if show_legend:
        tree_style.legend.add_face(TextFace("Effects"), column=0)
        tree_style.legend.add_face(TextFace(" "), column=1)
        for i in range(4, 0, -1):
            # Radius that my_layout gives an effect of magnitude eff_max * i / 4. Written without dividing
            # by eff_max so an all-zero tree does not produce NaN radii.
            radius = 10 * i / 4
            tree_style.legend.add_face(CircleFace(radius, "red"), column=0)
            tree_style.legend.add_face(TextFace(f"{-eff_max * i / 4:.2f} "), column=0)
            tree_style.legend.add_face(CircleFace(radius, "blue"), column=1)
            tree_style.legend.add_face(TextFace(f" {eff_max * i / 4:.2f}"), column=1)

    if show_leaf_effects:
        leaf_name = [node.name for node in tree2.traverse("postorder") if node.is_leaf()]
        leaf_effs = leaf_effs.loc[leaf_name].reset_index()
        palette = ["blue" if effect > 0 else "red" for effect in leaf_effs["Effect"].tolist()]

        # Render the tree to a temporary PNG to place it next to the bar plot, instead of writing
        # (and possibly overwriting) "tree_effect.png" in the current working directory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            img_path = Path(tmp_dir) / "tree_effect.png"
            tree2.render(str(img_path), tree_style=tree_style, units="in")
            img = mpimg.imread(img_path)

        _, ax = plt.subplots(1, 2, figsize=(10, 10))
        sns.barplot(data=leaf_effs, x="Effect", y="Cell Type", palette=palette, ax=ax[1])
        ax[0].imshow(img)
        ax[0].get_xaxis().set_visible(False)
        ax[0].get_yaxis().set_visible(False)
        ax[0].set_frame_on(False)

        ax[1].get_yaxis().set_visible(False)
        ax[1].spines["left"].set_visible(False)
        ax[1].spines["right"].set_visible(False)
        ax[1].spines["top"].set_visible(False)
        # Symmetric x-range so positive and negative bars are comparable.
        ax[1].set_xlim(-leaf_eff_max, leaf_eff_max)
        plt.subplots_adjust(wspace=0)

        if save:
            plt.savefig(save)
        return None

    if save:
        tree2.render(save, tree_style=tree_style, units=units, w=width, h=height, dpi=dpi)
    if show:
        return tree2.render("%%inline", tree_style=tree_style, units=units, w=width, h=height, dpi=dpi)
    return tree2, tree_style
|
2148
|
+
|
2149
|
+
def plot_effects_umap(  # pragma: no cover
    self,
    mdata: MuData,
    effect_name: str | list | None,
    cluster_key: str,
    modality_key_1: str = "rna",
    modality_key_2: str = "coda",
    color_map: Colormap | str | None = None,
    palette: str | Sequence[str] | None = None,
    return_fig: bool | None = None,
    ax: Axes = None,
    show: bool = None,
    save: str | bool | None = None,
    **kwargs,
) -> plt.Axes | plt.Figure | None:
    """Plot a UMAP visualization colored by effect strength.

    Effect results in .varm of aggregated sample-level AnnData (default is data['coda']) are assigned to cell-level AnnData
    (default is data['rna']) depending on the cluster they were assigned to.
    One column per effect (named like the effect) is added to `.obs` of the cell-level AnnData.

    Args:
        mdata: MuData object.
        effect_name: The name (or list of names) of the effect results in .varm of aggregated sample-level AnnData to plot.
        cluster_key: The cluster information in .obs of cell-level AnnData (default is data['rna']).
                     To assign cell types' effects to original cells.
        modality_key_1: Key to the cell-level AnnData in the MuData object. Defaults to "rna".
        modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object.
                        Defaults to "coda".
        color_map: Colormap passed to `scanpy.pl.umap()`. Defaults to None.
        palette: Palette passed to `scanpy.pl.umap()`. Defaults to None.
        return_fig: Passed to `scanpy.pl.umap()`; if True, return the figure. Defaults to None.
        ax: A matplotlib axes object. Only works if plotting a single component.
            Defaults to None.
        show: Whether to display the figure or return axis. Defaults to None.
        save: Passed to `scanpy.pl.umap()`. Defaults to None.
        **kwargs: All other keyword arguments are passed to `scanpy.plot.umap()`.
                  If `vmin`/`vmax` are not given (or None), the minimum/maximum over all plotted effects is used.

    Returns:
        Whatever `scanpy.pl.umap()` returns; if `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.

    Raises:
        KeyError: If a cluster in `cluster_key` has no entry in the effect DataFrame.

    Examples:
        >>> import pertpy as pt
        >>> import scanpy as sc
        >>> import schist
        >>> adata = pt.dt.haber_2017_regions()
        >>> sc.pp.neighbors(adata)
        >>> schist.inference.nested_model(adata, n_init=100, random_seed=5678)
        >>> tasccoda_model = pt.tl.Tasccoda()
        >>> tasccoda_data = tasccoda_model.load(adata, type="cell_level",
        ...                 cell_type_identifier="nsbm_level_1",
        ...                 sample_identifier="batch", covariate_obs=["condition"],
        ...                 levels_orig=["nsbm_level_4", "nsbm_level_3", "nsbm_level_2", "nsbm_level_1"],
        ...                 add_level_name=True)
        >>> tasccoda_model.prepare(
        ...     tasccoda_data,
        ...     modality_key="coda",
        ...     reference_cell_type="18",
        ...     formula="condition",
        ...     pen_args={"phi": 0, "lambda_1": 3.5},
        ...     tree_key="tree"
        ... )
        >>> tasccoda_model.run_nuts(
        ...     tasccoda_data, modality_key="coda", rng_key=1234, num_samples=10000, num_warmup=1000
        ... )
        >>> sc.tl.umap(tasccoda_data["rna"])
        >>> tasccoda_model.plot_effects_umap(tasccoda_data,
        ...     effect_name=["effect_df_condition[T.Salmonella]",
        ...                  "effect_df_condition[T.Hpoly.Day3]",
        ...                  "effect_df_condition[T.Hpoly.Day10]"],
        ...     cluster_key="nsbm_level_1",
        ... )

    Preview:
        .. image:: /_static/docstring_previews/tasccoda_effects_umap.png
    """
    # TODO: test the example
    data_rna = mdata[modality_key_1]
    data_coda = mdata[modality_key_2]
    if isinstance(effect_name, str):
        effect_name = [effect_name]

    # Effect DataFrames are indexed by the string form of the cluster label.
    cluster_labels = data_rna.obs[cluster_key].astype(str)
    for effect in effect_name:
        cluster_effects = data_coda.varm[effect]["Effect"]
        missing = set(cluster_labels.unique()) - set(cluster_effects.index)
        if missing:
            raise KeyError(f"Clusters {sorted(missing)} of '{cluster_key}' have no entry in varm['{effect}'].")
        # One vectorised lookup instead of a pandas .loc call per cell.
        data_rna.obs[effect] = cluster_labels.map(cluster_effects).to_numpy()

    # Pop vmin/vmax unconditionally: a user value of 0 must be honoured, and the key must never
    # stay in kwargs while also being passed explicitly below (duplicate-keyword TypeError).
    vmin = kwargs.pop("vmin", None)
    if vmin is None:
        vmin = min(data_rna.obs[effect].min() for effect in effect_name)
    vmax = kwargs.pop("vmax", None)
    if vmax is None:
        vmax = max(data_rna.obs[effect].max() for effect in effect_name)

    return sc.pl.umap(
        data_rna,
        color=effect_name,
        vmax=vmax,
        vmin=vmin,
        palette=palette,
        color_map=color_map,
        return_fig=return_fig,
        ax=ax,
        show=show,
        save=save,
        **kwargs,
    )
|
2254
|
+
|
1116
2255
|
|
1117
2256
|
def get_a(
|
1118
2257
|
tree: tt.tree,
|
@@ -1242,7 +2381,7 @@ def df2newick(df: pd.DataFrame, levels: list[str], inner_label: bool = True) ->
|
|
1242
2381
|
|
1243
2382
|
|
1244
2383
|
def get_a_2(
|
1245
|
-
tree:
|
2384
|
+
tree: Tree,
|
1246
2385
|
leaf_order: list[str] = None,
|
1247
2386
|
node_order: list[str] = None,
|
1248
2387
|
) -> tuple[np.ndarray, int]:
|
@@ -1263,6 +2402,13 @@ def get_a_2(
|
|
1263
2402
|
T
|
1264
2403
|
number of nodes in the tree, excluding the root node
|
1265
2404
|
"""
|
2405
|
+
try:
|
2406
|
+
import ete3 as ete
|
2407
|
+
except ImportError:
|
2408
|
+
raise ImportError(
|
2409
|
+
"To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
|
2410
|
+
) from None
|
2411
|
+
|
1266
2412
|
n_tips = len(tree.get_leaves())
|
1267
2413
|
n_nodes = len(tree.get_descendants())
|
1268
2414
|
|
@@ -1292,7 +2438,7 @@ def get_a_2(
|
|
1292
2438
|
return A_, n_nodes
|
1293
2439
|
|
1294
2440
|
|
1295
|
-
def collapse_singularities_2(tree:
|
2441
|
+
def collapse_singularities_2(tree: Tree) -> Tree:
|
1296
2442
|
"""Collapses (deletes) nodes in a ete3 tree that are singularities (have only one child).
|
1297
2443
|
|
1298
2444
|
Args:
|
@@ -1327,10 +2473,10 @@ def linkage_to_newick(
|
|
1327
2473
|
|
1328
2474
|
def build_newick(node, newick, parentdist, leaf_names):
|
1329
2475
|
if node.is_leaf():
|
1330
|
-
return f"{leaf_names[node.id]}:{(parentdist - node.dist)/2}{newick}"
|
2476
|
+
return f"{leaf_names[node.id]}:{(parentdist - node.dist) / 2}{newick}"
|
1331
2477
|
else:
|
1332
2478
|
if len(newick) > 0:
|
1333
|
-
newick = f"):{(parentdist - node.dist)/2}{newick}"
|
2479
|
+
newick = f"):{(parentdist - node.dist) / 2}{newick}"
|
1334
2480
|
else:
|
1335
2481
|
newick = ");"
|
1336
2482
|
newick = build_newick(node.get_left(), newick, node.dist, leaf_names)
|
@@ -1363,14 +2509,15 @@ def import_tree(
|
|
1363
2509
|
|
1364
2510
|
Args:
|
1365
2511
|
data: A tascCODA-compatible data object.
|
1366
|
-
modality_1: If `data` is MuData,
|
1367
|
-
modality_2: If `data` is MuData,
|
2512
|
+
modality_1: If `data` is MuData, specify the modality name to the original cell level anndata object. Defaults to None.
|
2513
|
+
modality_2: If `data` is MuData, specify the modality name to the aggregated level anndata object. Defaults to None.
|
1368
2514
|
dendrogram_key: Key to the scanpy.tl.dendrogram result in `.uns` of original cell level anndata object. Defaults to None.
|
1369
2515
|
levels_orig: List that indicates which columns in `.obs` of the original data correspond to tree levels. The list must begin with the root level, and end with the leaf level. Defaults to None.
|
1370
2516
|
levels_agg: List that indicates which columns in `.var` of the aggregated data correspond to tree levels. The list must begin with the root level, and end with the leaf level. Defaults to None.
|
1371
|
-
add_level_name: If True, internal nodes in the tree will be named as "{level_name}_{node_name}" instead of just {level_name}.
|
1372
|
-
|
1373
|
-
|
2517
|
+
add_level_name: If True, internal nodes in the tree will be named as "{level_name}_{node_name}" instead of just {level_name}.
|
2518
|
+
Defaults to True.
|
2519
|
+
key_added: If not specified, the tree is stored in .uns[‘tree’]. If `data` is AnnData, save tree in `data`.
|
2520
|
+
If `data` is MuData, save tree in data[modality_2]. Defaults to "tree".
|
1374
2521
|
|
1375
2522
|
Returns:
|
1376
2523
|
Updates data with the following:
|
@@ -1379,6 +2526,13 @@ def import_tree(
|
|
1379
2526
|
|
1380
2527
|
tree: A ete3 tree object.
|
1381
2528
|
"""
|
2529
|
+
try:
|
2530
|
+
import ete3 as ete
|
2531
|
+
except ImportError:
|
2532
|
+
raise ImportError(
|
2533
|
+
"To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
|
2534
|
+
) from None
|
2535
|
+
|
1382
2536
|
if isinstance(data, MuData):
|
1383
2537
|
try:
|
1384
2538
|
data_1 = data[modality_1]
|
@@ -1443,16 +2597,17 @@ def from_scanpy(
|
|
1443
2597
|
|
1444
2598
|
The anndata object needs to have a column in adata.obs that contains the cell type assignment.
|
1445
2599
|
Further, it must contain one column or a set of columns (e.g. subject id, treatment, disease status) that uniquely identify each (statistical) sample.
|
1446
|
-
Further covariates (e.g. subject age) can either be specified via
|
2600
|
+
Further covariates (e.g. subject age) can either be specified via additional column names in adata.obs, a key in adata.uns, or as a separate DataFrame.
|
1447
2601
|
|
1448
|
-
NOTE: The order of samples in the returned dataset is determined by the first
|
2602
|
+
NOTE: The order of samples in the returned dataset is determined by the first occurrence of cells from each sample in `adata`
|
1449
2603
|
|
1450
2604
|
Args:
|
1451
2605
|
adata: An anndata object from scanpy
|
1452
2606
|
cell_type_identifier: column name in adata.obs that specifies the cell types
|
1453
2607
|
sample_identifier: column name or list of column names in adata.obs that uniquely identify each sample
|
1454
2608
|
covariate_uns: key for adata.uns, where covariate values are stored
|
1455
|
-
covariate_obs: list of column names in adata.obs, where covariate values are stored.
|
2609
|
+
covariate_obs: list of column names in adata.obs, where covariate values are stored.
|
2610
|
+
Note: If covariate values are not unique for a value of sample_identifier, this covariate will be skipped.
|
1456
2611
|
covariate_df: DataFrame with covariates
|
1457
2612
|
|
1458
2613
|
Returns:
|
@@ -1461,50 +2616,40 @@ def from_scanpy(
|
|
1461
2616
|
if isinstance(sample_identifier, str):
|
1462
2617
|
sample_identifier = [sample_identifier]
|
1463
2618
|
|
1464
|
-
if
|
1465
|
-
covariate_obs += [i for i in sample_identifier if i not in covariate_obs]
|
1466
|
-
else:
|
1467
|
-
covariate_obs = sample_identifier # type: ignore
|
1468
|
-
|
1469
|
-
# join sample identifiers
|
1470
|
-
if isinstance(sample_identifier, list):
|
2619
|
+
if len(sample_identifier) > 1:
|
1471
2620
|
adata.obs["scCODA_sample_id"] = adata.obs[sample_identifier].agg("-".join, axis=1)
|
1472
2621
|
sample_identifier = "scCODA_sample_id"
|
2622
|
+
else:
|
2623
|
+
sample_identifier = sample_identifier[0]
|
1473
2624
|
|
1474
2625
|
# get cell type counts
|
1475
|
-
|
1476
|
-
|
1477
|
-
count_data = count_data.fillna(0)
|
2626
|
+
ct_count_data = pd.crosstab(adata.obs[sample_identifier], adata.obs[cell_type_identifier])
|
2627
|
+
ct_count_data = ct_count_data.fillna(0)
|
1478
2628
|
|
1479
2629
|
# get covariates from different sources
|
1480
|
-
covariate_df_ = pd.DataFrame(index=
|
1481
|
-
|
1482
|
-
if covariate_df is None and covariate_obs is None and covariate_uns is None:
|
1483
|
-
print("No covariate information specified!")
|
2630
|
+
covariate_df_ = pd.DataFrame(index=ct_count_data.index)
|
1484
2631
|
|
1485
2632
|
if covariate_uns is not None:
|
1486
|
-
covariate_df_uns = pd.DataFrame(adata.uns[covariate_uns])
|
1487
|
-
covariate_df_ =
|
2633
|
+
covariate_df_uns = pd.DataFrame(adata.uns[covariate_uns], index=ct_count_data.index)
|
2634
|
+
covariate_df_ = covariate_df_.join(covariate_df_uns, how="left")
|
1488
2635
|
|
1489
|
-
if covariate_obs
|
1490
|
-
|
1491
|
-
|
1492
|
-
print(f"Covariate {c} has non-unique values! Skipping...")
|
1493
|
-
covariate_obs.remove(c)
|
2636
|
+
if covariate_obs:
|
2637
|
+
is_unique = adata.obs.groupby(sample_identifier, observed=True).transform(lambda x: x.nunique() == 1)
|
2638
|
+
unique_covariates = is_unique.columns[is_unique.all()].tolist()
|
1494
2639
|
|
1495
|
-
|
1496
|
-
|
2640
|
+
if len(unique_covariates) < len(covariate_obs):
|
2641
|
+
skipped = set(covariate_obs) - set(unique_covariates)
|
2642
|
+
print(f"[bold yellow]Covariates {skipped} have non-unique values! Skipping...")
|
2643
|
+
if unique_covariates:
|
2644
|
+
covariate_df_obs = adata.obs.groupby(sample_identifier, observed=True).first()[unique_covariates]
|
2645
|
+
covariate_df_ = covariate_df_.join(covariate_df_obs, how="left")
|
1497
2646
|
|
1498
2647
|
if covariate_df is not None:
|
1499
|
-
if
|
1500
|
-
raise ValueError("
|
1501
|
-
|
1502
|
-
covariate_df_ = pd.concat((covariate_df_, covs_ord), axis=1)
|
1503
|
-
|
1504
|
-
covariate_df_.index = covariate_df_.index.astype(str)
|
2648
|
+
if not covariate_df.index.equals(ct_count_data.index):
|
2649
|
+
raise ValueError("AnnData sample names and covariate_df index do not have the same elements!")
|
2650
|
+
covariate_df_ = covariate_df_.join(covariate_df, how="left")
|
1505
2651
|
|
1506
|
-
|
1507
|
-
var_dat = count_data.sum(axis=0).rename("n_cells").to_frame()
|
2652
|
+
var_dat = ct_count_data.sum(axis=0).rename("n_cells").to_frame()
|
1508
2653
|
var_dat.index = var_dat.index.astype(str)
|
1509
2654
|
|
1510
|
-
return AnnData(X=
|
2655
|
+
return AnnData(X=ct_count_data.values, var=var_dat, obs=covariate_df_)
|