pertpy 0.10.0__py3-none-any.whl → 0.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +5 -1
- pertpy/_doc.py +1 -3
- pertpy/_types.py +6 -0
- pertpy/data/_dataloader.py +68 -24
- pertpy/data/_datasets.py +9 -9
- pertpy/metadata/__init__.py +2 -1
- pertpy/metadata/_cell_line.py +133 -25
- pertpy/metadata/_look_up.py +13 -19
- pertpy/metadata/_moa.py +1 -1
- pertpy/preprocessing/_guide_rna.py +138 -44
- pertpy/preprocessing/_guide_rna_mixture.py +17 -19
- pertpy/tools/__init__.py +4 -3
- pertpy/tools/_augur.py +106 -98
- pertpy/tools/_cinemaot.py +74 -114
- pertpy/tools/_coda/_base_coda.py +134 -148
- pertpy/tools/_coda/_sccoda.py +69 -70
- pertpy/tools/_coda/_tasccoda.py +74 -80
- pertpy/tools/_dialogue.py +48 -41
- pertpy/tools/_differential_gene_expression/_base.py +21 -31
- pertpy/tools/_differential_gene_expression/_checks.py +4 -6
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +5 -6
- pertpy/tools/_differential_gene_expression/_edger.py +6 -10
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +1 -1
- pertpy/tools/_differential_gene_expression/_simple_tests.py +3 -3
- pertpy/tools/_differential_gene_expression/_statsmodels.py +8 -5
- pertpy/tools/_distances/_distance_tests.py +1 -2
- pertpy/tools/_distances/_distances.py +31 -46
- pertpy/tools/_enrichment.py +7 -22
- pertpy/tools/_milo.py +19 -15
- pertpy/tools/_mixscape.py +73 -75
- pertpy/tools/_perturbation_space/_clustering.py +4 -4
- pertpy/tools/_perturbation_space/_comparison.py +4 -4
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +83 -32
- pertpy/tools/_perturbation_space/_perturbation_space.py +10 -10
- pertpy/tools/_perturbation_space/_simple.py +12 -14
- pertpy/tools/_scgen/_scgen.py +16 -17
- pertpy/tools/_scgen/_scgenvae.py +2 -2
- pertpy/tools/_scgen/_utils.py +3 -1
- {pertpy-0.10.0.dist-info → pertpy-0.11.1.dist-info}/METADATA +42 -24
- pertpy-0.11.1.dist-info/RECORD +58 -0
- {pertpy-0.10.0.dist-info → pertpy-0.11.1.dist-info}/licenses/LICENSE +1 -0
- pertpy/tools/_kernel_pca.py +0 -50
- pertpy-0.10.0.dist-info/RECORD +0 -58
- {pertpy-0.10.0.dist-info → pertpy-0.11.1.dist-info}/WHEEL +0 -0
pertpy/tools/_coda/_base_coda.py
CHANGED
@@ -4,12 +4,10 @@ from abc import ABC, abstractmethod
|
|
4
4
|
from pathlib import Path
|
5
5
|
from typing import TYPE_CHECKING, Literal
|
6
6
|
|
7
|
-
import arviz as az
|
8
7
|
import jax.numpy as jnp
|
9
8
|
import matplotlib.pyplot as plt
|
10
9
|
import numpy as np
|
11
10
|
import pandas as pd
|
12
|
-
import patsy as pt
|
13
11
|
import scanpy as sc
|
14
12
|
import seaborn as sns
|
15
13
|
from adjustText import adjust_text
|
@@ -33,7 +31,7 @@ if TYPE_CHECKING:
|
|
33
31
|
|
34
32
|
import numpyro as npy
|
35
33
|
import toytree as tt
|
36
|
-
from
|
34
|
+
from ete4 import Tree
|
37
35
|
from jax._src.typing import Array
|
38
36
|
from matplotlib.axes import Axes
|
39
37
|
from matplotlib.colors import Colormap
|
@@ -126,7 +124,9 @@ class CompositionalModel2(ABC):
|
|
126
124
|
sample_adata.X = sample_adata.X.astype(dtype)
|
127
125
|
|
128
126
|
# Build covariate matrix from R-like formula, save in obsm
|
129
|
-
|
127
|
+
import patsy
|
128
|
+
|
129
|
+
covariate_matrix = patsy.dmatrix(formula, sample_adata.obs)
|
130
130
|
covariate_names = covariate_matrix.design_info.column_names[1:]
|
131
131
|
sample_adata.obsm["covariate_matrix"] = np.array(covariate_matrix[:, 1:]).astype(dtype)
|
132
132
|
|
@@ -198,7 +198,7 @@ class CompositionalModel2(ABC):
|
|
198
198
|
*args,
|
199
199
|
**kwargs,
|
200
200
|
):
|
201
|
-
"""Background function that executes any numpyro MCMC algorithm and processes its results
|
201
|
+
"""Background function that executes any numpyro MCMC algorithm and processes its results.
|
202
202
|
|
203
203
|
Args:
|
204
204
|
sample_adata: anndata object with cell counts as sample_adata.X and covariates saved in sample_adata.obs.
|
@@ -294,6 +294,8 @@ class CompositionalModel2(ABC):
|
|
294
294
|
num_warmup: Number of burn-in (warmup) samples.
|
295
295
|
rng_key: The rng state used.
|
296
296
|
copy: Return a copy instead of writing to adata.
|
297
|
+
*args: Additional args passed to numpyro NUTS
|
298
|
+
**kwargs: Additional kwargs passed to numpyro NUTS
|
297
299
|
|
298
300
|
Returns:
|
299
301
|
Calls `self.__run_mcmc`
|
@@ -347,6 +349,8 @@ class CompositionalModel2(ABC):
|
|
347
349
|
num_warmup: Number of burn-in (warmup) samples.
|
348
350
|
rng_key: The rng state used. If None, a random state will be selected.
|
349
351
|
copy: Return a copy instead of writing to adata.
|
352
|
+
*args: Additional args passed to numpyro HMC
|
353
|
+
**kwargs: Additional kwargs passed to numpyro HMC
|
350
354
|
|
351
355
|
Examples:
|
352
356
|
>>> import pertpy as pt
|
@@ -396,7 +400,8 @@ class CompositionalModel2(ABC):
|
|
396
400
|
self, sample_adata: AnnData, est_fdr: float = 0.05, *args, **kwargs
|
397
401
|
) -> tuple[pd.DataFrame, pd.DataFrame] | tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
398
402
|
"""Generates summary dataframes for intercepts, effects and node-level effect (if using tree aggregation).
|
399
|
-
|
403
|
+
|
404
|
+
This function builds on and supports all functionalities from ``az.summary``.
|
400
405
|
|
401
406
|
Args:
|
402
407
|
sample_adata: Anndata object with cell counts as sample_adata.X and covariates saved in sample_adata.obs.
|
@@ -405,7 +410,7 @@ class CompositionalModel2(ABC):
|
|
405
410
|
kwargs: Passed to ``az.summary``
|
406
411
|
|
407
412
|
Returns:
|
408
|
-
Tuple[
|
413
|
+
Tuple[:class:pandas.DataFrame, :class:pandas.DataFrame] or Tuple[:class:pandas.DataFrame, :class:pandas.DataFrame, :class:pandas.DataFrame]: Intercept, effect and node-level DataFrames
|
409
414
|
|
410
415
|
intercept_df
|
411
416
|
Summary of intercept parameters. Contains one row per cell type.
|
@@ -435,7 +440,7 @@ class CompositionalModel2(ABC):
|
|
435
440
|
- Delta: Decision boundary value - threshold of practical significance
|
436
441
|
- Is credible: Boolean indicator whether effect is credible
|
437
442
|
|
438
|
-
|
443
|
+
Examples:
|
439
444
|
>>> import pertpy as pt
|
440
445
|
>>> haber_cells = pt.dt.haber_2017_regions()
|
441
446
|
>>> sccoda = pt.tl.Sccoda()
|
@@ -456,6 +461,8 @@ class CompositionalModel2(ABC):
|
|
456
461
|
else:
|
457
462
|
raise ValueError("No valid model type!")
|
458
463
|
|
464
|
+
import arviz as az
|
465
|
+
|
459
466
|
summ = az.summary(
|
460
467
|
data=self.make_arviz(sample_adata, num_prior_samples=0, use_posterior_predictive=False),
|
461
468
|
var_names=var_names,
|
@@ -684,7 +691,7 @@ class CompositionalModel2(ABC):
|
|
684
691
|
|
685
692
|
if fdr < alpha:
|
686
693
|
# ceiling with 3 decimals precision
|
687
|
-
c = np.floor(c * 10**3) / 10**3
|
694
|
+
c = np.floor(c * 10**3) / 10**3 # noqa: PLW2901
|
688
695
|
return c, fdr
|
689
696
|
return 1.0, 0
|
690
697
|
|
@@ -737,7 +744,8 @@ class CompositionalModel2(ABC):
|
|
737
744
|
node_df: pd.DataFrame,
|
738
745
|
) -> pd.DataFrame:
|
739
746
|
"""Evaluation of MCMC results for node-level effect parameters. This function is only used within self.summary_prepare.
|
740
|
-
|
747
|
+
|
748
|
+
This function determines whether node-level effects are credible or not.
|
741
749
|
|
742
750
|
Args:
|
743
751
|
sample_adata: Anndata object with cell counts as sample_adata.X and covariates saved in sample_adata.obs.
|
@@ -932,15 +940,15 @@ class CompositionalModel2(ABC):
|
|
932
940
|
)
|
933
941
|
console.print(table)
|
934
942
|
|
935
|
-
def get_intercept_df(self, data: AnnData | MuData, modality_key: str = "coda"):
|
936
|
-
"""Get intercept dataframe as printed in the extended summary
|
943
|
+
def get_intercept_df(self, data: AnnData | MuData, modality_key: str = "coda") -> pd.DataFrame:
|
944
|
+
"""Get intercept dataframe as printed in the extended summary.
|
937
945
|
|
938
946
|
Args:
|
939
947
|
data: AnnData object or MuData object.
|
940
948
|
modality_key: If data is a MuData object, specify which modality to use.
|
941
949
|
|
942
950
|
Returns:
|
943
|
-
|
951
|
+
Intercept data frame.
|
944
952
|
|
945
953
|
Examples:
|
946
954
|
>>> import pertpy as pt
|
@@ -963,15 +971,15 @@ class CompositionalModel2(ABC):
|
|
963
971
|
|
964
972
|
return sample_adata.varm["intercept_df"]
|
965
973
|
|
966
|
-
def get_effect_df(self, data: AnnData | MuData, modality_key: str = "coda"):
|
967
|
-
"""Get effect dataframe as printed in the extended summary
|
974
|
+
def get_effect_df(self, data: AnnData | MuData, modality_key: str = "coda") -> pd.DataFrame:
|
975
|
+
"""Get effect dataframe as printed in the extended summary.
|
968
976
|
|
969
977
|
Args:
|
970
978
|
data: AnnData object or MuData object.
|
971
979
|
modality_key: If data is a MuData object, specify which modality to use.
|
972
980
|
|
973
981
|
Returns:
|
974
|
-
|
982
|
+
Effect data frame.
|
975
983
|
|
976
984
|
Examples:
|
977
985
|
>>> import pertpy as pt
|
@@ -1005,15 +1013,15 @@ class CompositionalModel2(ABC):
|
|
1005
1013
|
|
1006
1014
|
return effect_df
|
1007
1015
|
|
1008
|
-
def get_node_df(self, data: AnnData | MuData, modality_key: str = "coda"):
|
1009
|
-
"""Get node effect dataframe as printed in the extended summary of a tascCODA model
|
1016
|
+
def get_node_df(self, data: AnnData | MuData, modality_key: str = "coda") -> pd.DataFrame:
|
1017
|
+
"""Get node effect dataframe as printed in the extended summary of a tascCODA model.
|
1010
1018
|
|
1011
1019
|
Args:
|
1012
1020
|
data: AnnData object or MuData object.
|
1013
1021
|
modality_key: If data is a MuData object, specify which modality to use.
|
1014
1022
|
|
1015
1023
|
Returns:
|
1016
|
-
|
1024
|
+
Node effect data frame.
|
1017
1025
|
|
1018
1026
|
Examples:
|
1019
1027
|
>>> import pertpy as pt
|
@@ -1030,7 +1038,6 @@ class CompositionalModel2(ABC):
|
|
1030
1038
|
>>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
|
1031
1039
|
>>> node_effects = tasccoda.get_node_df(mdata)
|
1032
1040
|
"""
|
1033
|
-
|
1034
1041
|
if isinstance(data, MuData):
|
1035
1042
|
try:
|
1036
1043
|
sample_adata = data[modality_key]
|
@@ -1043,8 +1050,9 @@ class CompositionalModel2(ABC):
|
|
1043
1050
|
return sample_adata.uns["scCODA_params"]["node_df"]
|
1044
1051
|
|
1045
1052
|
def set_fdr(self, data: AnnData | MuData, est_fdr: float, modality_key: str = "coda", *args, **kwargs):
|
1046
|
-
"""Direct posterior probability approach to calculate credible effects while keeping the expected FDR at a certain level
|
1047
|
-
|
1053
|
+
"""Direct posterior probability approach to calculate credible effects while keeping the expected FDR at a certain level.
|
1054
|
+
|
1055
|
+
Note: Does not work for spike-and-slab LASSO selection method.
|
1048
1056
|
|
1049
1057
|
Args:
|
1050
1058
|
data: AnnData object or MuData object.
|
@@ -1079,7 +1087,8 @@ class CompositionalModel2(ABC):
|
|
1079
1087
|
|
1080
1088
|
def credible_effects(self, data: AnnData | MuData, modality_key: str = "coda", est_fdr: float = None) -> pd.Series:
|
1081
1089
|
"""Decides which effects of the scCODA model are credible based on an adjustable inclusion probability threshold.
|
1082
|
-
|
1090
|
+
|
1091
|
+
Note: Parameter est_fdr has no effect for spike-and-slab LASSO selection method.
|
1083
1092
|
|
1084
1093
|
Args:
|
1085
1094
|
data: AnnData object or MuData object.
|
@@ -1087,7 +1096,7 @@ class CompositionalModel2(ABC):
|
|
1087
1096
|
est_fdr: Estimated false discovery rate. Must be between 0 and 1.
|
1088
1097
|
|
1089
1098
|
Returns:
|
1090
|
-
|
1099
|
+
Credible effect decision series which includes boolean values indicate whether effects are credible under inc_prob_threshold.
|
1091
1100
|
"""
|
1092
1101
|
if isinstance(data, MuData):
|
1093
1102
|
try:
|
@@ -1109,16 +1118,15 @@ class CompositionalModel2(ABC):
|
|
1109
1118
|
else:
|
1110
1119
|
_, eff_df = self.summary_prepare(sample_adata, est_fdr=est_fdr) # type: ignore
|
1111
1120
|
# otherwise, get pre-calculated DataFrames. Effect DataFrame is stitched together from varm
|
1121
|
+
elif model_type == "tree_agg" and select_type == "sslasso":
|
1122
|
+
eff_df = sample_adata.uns["scCODA_params"]["node_df"]
|
1112
1123
|
else:
|
1113
|
-
|
1114
|
-
|
1115
|
-
|
1116
|
-
|
1117
|
-
|
1118
|
-
|
1119
|
-
eff_df.index = pd.MultiIndex.from_product(
|
1120
|
-
(covariates, sample_adata.var.index.tolist()), names=["Covariate", "Cell Type"]
|
1121
|
-
)
|
1124
|
+
covariates = sample_adata.uns["scCODA_params"]["covariate_names"]
|
1125
|
+
effect_dfs = [sample_adata.varm[f"effect_df_{cov}"] for cov in covariates]
|
1126
|
+
eff_df = pd.concat(effect_dfs)
|
1127
|
+
eff_df.index = pd.MultiIndex.from_product(
|
1128
|
+
(covariates, sample_adata.var.index.tolist()), names=["Covariate", "Cell Type"]
|
1129
|
+
)
|
1122
1130
|
|
1123
1131
|
out = eff_df["Final Parameter"] != 0
|
1124
1132
|
out.rename("credible change")
|
@@ -1188,7 +1196,7 @@ class CompositionalModel2(ABC):
|
|
1188
1196
|
return ax
|
1189
1197
|
|
1190
1198
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
1191
|
-
def plot_stacked_barplot( # pragma: no cover
|
1199
|
+
def plot_stacked_barplot( # pragma: no cover # noqa: D417
|
1192
1200
|
self,
|
1193
1201
|
data: AnnData | MuData,
|
1194
1202
|
feature_name: str,
|
@@ -1215,7 +1223,7 @@ class CompositionalModel2(ABC):
|
|
1215
1223
|
{common_plot_args}
|
1216
1224
|
|
1217
1225
|
Returns:
|
1218
|
-
If `return_fig` is `True`, returns the
|
1226
|
+
If `return_fig` is `True`, returns the Figure, otherwise `None`.
|
1219
1227
|
|
1220
1228
|
Examples:
|
1221
1229
|
>>> import pertpy as pt
|
@@ -1230,8 +1238,6 @@ class CompositionalModel2(ABC):
|
|
1230
1238
|
"""
|
1231
1239
|
if isinstance(data, MuData):
|
1232
1240
|
data = data[modality_key]
|
1233
|
-
if isinstance(data, AnnData):
|
1234
|
-
data = data
|
1235
1241
|
|
1236
1242
|
ct_names = data.var.index
|
1237
1243
|
|
@@ -1283,7 +1289,7 @@ class CompositionalModel2(ABC):
|
|
1283
1289
|
return None
|
1284
1290
|
|
1285
1291
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
1286
|
-
def plot_effects_barplot( # pragma: no cover
|
1292
|
+
def plot_effects_barplot( # pragma: no cover # noqa: D417
|
1287
1293
|
self,
|
1288
1294
|
data: AnnData | MuData,
|
1289
1295
|
*,
|
@@ -1340,8 +1346,7 @@ class CompositionalModel2(ABC):
|
|
1340
1346
|
args_barplot = {}
|
1341
1347
|
if isinstance(data, MuData):
|
1342
1348
|
data = data[modality_key]
|
1343
|
-
|
1344
|
-
data = data
|
1349
|
+
|
1345
1350
|
# Get covariate names from adata, partition into those with nonzero effects for min. one cell type/no cell types
|
1346
1351
|
covariate_names = data.uns["scCODA_params"]["covariate_names"]
|
1347
1352
|
if covariates is not None:
|
@@ -1372,18 +1377,16 @@ class CompositionalModel2(ABC):
|
|
1372
1377
|
|
1373
1378
|
plot_df = plot_df.reset_index()
|
1374
1379
|
|
1375
|
-
if len(covariate_names_zero) != 0:
|
1376
|
-
|
1377
|
-
|
1378
|
-
|
1379
|
-
|
1380
|
-
|
1381
|
-
|
1382
|
-
|
1383
|
-
|
1384
|
-
|
1385
|
-
plot_df["covariate_"] = pd.Categorical(plot_df["Covariate"], covariate_names)
|
1386
|
-
plot_df = plot_df.sort_values(["covariate_"])
|
1380
|
+
if len(covariate_names_zero) != 0 and plot_facets and plot_zero_covariate and not plot_zero_cell_type:
|
1381
|
+
for covariate_name_zero in covariate_names_zero:
|
1382
|
+
new_row = {
|
1383
|
+
"Covariate": covariate_name_zero,
|
1384
|
+
"Cell Type": "zero",
|
1385
|
+
"value": 0,
|
1386
|
+
}
|
1387
|
+
plot_df = pd.concat([plot_df, pd.DataFrame([new_row])], ignore_index=True)
|
1388
|
+
plot_df["covariate_"] = pd.Categorical(plot_df["Covariate"], covariate_names)
|
1389
|
+
plot_df = plot_df.sort_values(["covariate_"])
|
1387
1390
|
if not plot_zero_cell_type:
|
1388
1391
|
cell_type_names_zero = [
|
1389
1392
|
name
|
@@ -1427,9 +1430,8 @@ class CompositionalModel2(ABC):
|
|
1427
1430
|
ax.set_title(covariate_names[i])
|
1428
1431
|
if len(ax.get_xticklabels()) < 5:
|
1429
1432
|
ax.set_aspect(10 / len(ax.get_xticklabels()))
|
1430
|
-
if len(ax.get_xticklabels()) == 1:
|
1431
|
-
|
1432
|
-
ax.set_xticks([])
|
1433
|
+
if len(ax.get_xticklabels()) == 1 and ax.get_xticklabels()[0]._text == "zero":
|
1434
|
+
ax.set_xticks([])
|
1433
1435
|
|
1434
1436
|
# If not plot as facets, call barplot to plot cell types on the x-axis.
|
1435
1437
|
else:
|
@@ -1460,6 +1462,7 @@ class CompositionalModel2(ABC):
|
|
1460
1462
|
ax=ax,
|
1461
1463
|
)
|
1462
1464
|
cell_types = pd.unique(plot_df["Cell Type"])
|
1465
|
+
ax.set_xticks(cell_types)
|
1463
1466
|
ax.set_xticklabels(cell_types, rotation=90)
|
1464
1467
|
|
1465
1468
|
if return_fig and plot_facets:
|
@@ -1470,7 +1473,7 @@ class CompositionalModel2(ABC):
|
|
1470
1473
|
return None
|
1471
1474
|
|
1472
1475
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
1473
|
-
def plot_boxplots( # pragma: no cover
|
1476
|
+
def plot_boxplots( # pragma: no cover # noqa: D417
|
1474
1477
|
self,
|
1475
1478
|
data: AnnData | MuData,
|
1476
1479
|
feature_name: str,
|
@@ -1532,8 +1535,7 @@ class CompositionalModel2(ABC):
|
|
1532
1535
|
args_swarmplot = {}
|
1533
1536
|
if isinstance(data, MuData):
|
1534
1537
|
data = data[modality_key]
|
1535
|
-
|
1536
|
-
data = data
|
1538
|
+
|
1537
1539
|
# y scale transformations
|
1538
1540
|
if y_scale == "relative":
|
1539
1541
|
sample_sums = np.sum(data.X, axis=1, keepdims=True)
|
@@ -1607,10 +1609,7 @@ class CompositionalModel2(ABC):
|
|
1607
1609
|
)
|
1608
1610
|
|
1609
1611
|
if add_dots:
|
1610
|
-
if "hue" in args_swarmplot
|
1611
|
-
hue = args_swarmplot.pop("hue")
|
1612
|
-
else:
|
1613
|
-
hue = None
|
1612
|
+
hue = args_swarmplot.pop("hue") if "hue" in args_swarmplot else None
|
1614
1613
|
|
1615
1614
|
if hue is None:
|
1616
1615
|
g.map(
|
@@ -1675,6 +1674,7 @@ class CompositionalModel2(ABC):
|
|
1675
1674
|
)
|
1676
1675
|
|
1677
1676
|
cell_types = pd.unique(plot_df["Cell type"])
|
1677
|
+
ax.set_xticks(cell_types)
|
1678
1678
|
ax.set_xticklabels(cell_types, rotation=90)
|
1679
1679
|
|
1680
1680
|
if show_legend:
|
@@ -1702,7 +1702,7 @@ class CompositionalModel2(ABC):
|
|
1702
1702
|
return None
|
1703
1703
|
|
1704
1704
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
1705
|
-
def plot_rel_abundance_dispersion_plot( # pragma: no cover
|
1705
|
+
def plot_rel_abundance_dispersion_plot( # pragma: no cover # noqa: D417
|
1706
1706
|
self,
|
1707
1707
|
data: AnnData | MuData,
|
1708
1708
|
*,
|
@@ -1750,8 +1750,7 @@ class CompositionalModel2(ABC):
|
|
1750
1750
|
"""
|
1751
1751
|
if isinstance(data, MuData):
|
1752
1752
|
data = data[modality_key]
|
1753
|
-
|
1754
|
-
data = data
|
1753
|
+
|
1755
1754
|
if ax is None:
|
1756
1755
|
_, ax = plt.subplots(figsize=figsize, dpi=dpi)
|
1757
1756
|
|
@@ -1823,13 +1822,13 @@ class CompositionalModel2(ABC):
|
|
1823
1822
|
return None
|
1824
1823
|
|
1825
1824
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
1826
|
-
def plot_draw_tree( # pragma: no cover
|
1825
|
+
def plot_draw_tree( # pragma: no cover # noqa: D417
|
1827
1826
|
self,
|
1828
1827
|
data: AnnData | MuData,
|
1829
1828
|
*,
|
1830
1829
|
modality_key: str = "coda",
|
1831
|
-
tree: str = "tree", # Also type
|
1832
|
-
tight_text: bool
|
1830
|
+
tree: str = "tree", # Also type ete4.Tree. Omitted due to import errors
|
1831
|
+
tight_text: bool = False,
|
1833
1832
|
show_scale: bool | None = False,
|
1834
1833
|
units: Literal["px", "mm", "in"] | None = "px",
|
1835
1834
|
figsize: tuple[float, float] | None = (None, None),
|
@@ -1837,12 +1836,12 @@ class CompositionalModel2(ABC):
|
|
1837
1836
|
save: str | bool = False,
|
1838
1837
|
return_fig: bool = False,
|
1839
1838
|
) -> Tree | None:
|
1840
|
-
"""Plot a tree using input
|
1839
|
+
"""Plot a tree using input ete4 tree object.
|
1841
1840
|
|
1842
1841
|
Args:
|
1843
1842
|
data: AnnData object or MuData object.
|
1844
1843
|
modality_key: If data is a MuData object, specify which modality to use.
|
1845
|
-
tree: A
|
1844
|
+
tree: A ete4 tree object or a str to indicate the tree stored in `.uns`.
|
1846
1845
|
tight_text: When False, boundaries of the text are approximated according to general font metrics,
|
1847
1846
|
producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
|
1848
1847
|
show_scale: Include the scale legend in the tree image or not.
|
@@ -1853,7 +1852,7 @@ class CompositionalModel2(ABC):
|
|
1853
1852
|
{common_plot_args}
|
1854
1853
|
|
1855
1854
|
Returns:
|
1856
|
-
Depending on `
|
1855
|
+
Depending on `save`, returns :class:`ete4.core.tree.Tree` and :class:`ete4.treeview.TreeStyle` (`save = 'output.png'`) or plot the tree inline (`save = False`)
|
1857
1856
|
|
1858
1857
|
Examples:
|
1859
1858
|
>>> import pertpy as pt
|
@@ -1874,7 +1873,8 @@ class CompositionalModel2(ABC):
|
|
1874
1873
|
.. image:: /_static/docstring_previews/tasccoda_draw_tree.png
|
1875
1874
|
"""
|
1876
1875
|
try:
|
1877
|
-
from
|
1876
|
+
from ete4 import Tree
|
1877
|
+
from ete4.treeview import CircleFace, NodeStyle, TextFace, TreeStyle, faces
|
1878
1878
|
except ImportError:
|
1879
1879
|
raise ImportError(
|
1880
1880
|
"To use tasccoda please install additional dependencies with `pip install pertpy[coda]`"
|
@@ -1882,8 +1882,6 @@ class CompositionalModel2(ABC):
|
|
1882
1882
|
|
1883
1883
|
if isinstance(data, MuData):
|
1884
1884
|
data = data[modality_key]
|
1885
|
-
if isinstance(data, AnnData):
|
1886
|
-
data = data
|
1887
1885
|
if isinstance(tree, str):
|
1888
1886
|
tree = data.uns[tree]
|
1889
1887
|
|
@@ -1896,7 +1894,7 @@ class CompositionalModel2(ABC):
|
|
1896
1894
|
tree_style.layout_fn = my_layout
|
1897
1895
|
tree_style.show_scale = show_scale
|
1898
1896
|
|
1899
|
-
if save
|
1897
|
+
if save:
|
1900
1898
|
tree.render(save, tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore
|
1901
1899
|
if return_fig:
|
1902
1900
|
return tree, tree_style
|
@@ -1904,13 +1902,13 @@ class CompositionalModel2(ABC):
|
|
1904
1902
|
return None
|
1905
1903
|
|
1906
1904
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
1907
|
-
def plot_draw_effects( # pragma: no cover
|
1905
|
+
def plot_draw_effects( # pragma: no cover # noqa: D417
|
1908
1906
|
self,
|
1909
1907
|
data: AnnData | MuData,
|
1910
1908
|
covariate: str,
|
1911
1909
|
*,
|
1912
1910
|
modality_key: str = "coda",
|
1913
|
-
tree: str = "tree", # Also type
|
1911
|
+
tree: str = "tree", # Also type ete4.Tree. Omitted due to import errors
|
1914
1912
|
show_legend: bool | None = None,
|
1915
1913
|
show_leaf_effects: bool | None = False,
|
1916
1914
|
tight_text: bool | None = False,
|
@@ -1927,7 +1925,7 @@ class CompositionalModel2(ABC):
|
|
1927
1925
|
data: AnnData object or MuData object.
|
1928
1926
|
covariate: The covariate, whose effects should be plotted.
|
1929
1927
|
modality_key: If data is a MuData object, specify which modality to use.
|
1930
|
-
tree: A
|
1928
|
+
tree: A ete4 tree object or a str to indicate the tree stored in `.uns`.
|
1931
1929
|
show_legend: If show legend of nodes significant effects or not.
|
1932
1930
|
Defaults to False if show_leaf_effects is True.
|
1933
1931
|
show_leaf_effects: If True, plot bar plots which indicate leave-level significant effects.
|
@@ -1941,8 +1939,8 @@ class CompositionalModel2(ABC):
|
|
1941
1939
|
{common_plot_args}
|
1942
1940
|
|
1943
1941
|
Returns:
|
1944
|
-
|
1945
|
-
or
|
1942
|
+
Depending on `save`, returns :class:`ete4.core.tree.Tree` and :class:`ete4.treeview.TreeStyle` (`save = 'output.png'`)
|
1943
|
+
or plot the tree inline (`save = False`).
|
1946
1944
|
|
1947
1945
|
Examples:
|
1948
1946
|
>>> import pertpy as pt
|
@@ -1963,7 +1961,8 @@ class CompositionalModel2(ABC):
|
|
1963
1961
|
.. image:: /_static/docstring_previews/tasccoda_draw_effects.png
|
1964
1962
|
"""
|
1965
1963
|
try:
|
1966
|
-
from
|
1964
|
+
from ete4 import Tree
|
1965
|
+
from ete4.treeview import CircleFace, NodeStyle, TextFace, TreeStyle, faces
|
1967
1966
|
except ImportError:
|
1968
1967
|
raise ImportError(
|
1969
1968
|
"To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
|
@@ -1971,8 +1970,6 @@ class CompositionalModel2(ABC):
|
|
1971
1970
|
|
1972
1971
|
if isinstance(data, MuData):
|
1973
1972
|
data = data[modality_key]
|
1974
|
-
if isinstance(data, AnnData):
|
1975
|
-
data = data
|
1976
1973
|
if show_legend is None:
|
1977
1974
|
show_legend = not show_leaf_effects
|
1978
1975
|
elif show_legend:
|
@@ -2003,18 +2000,18 @@ class CompositionalModel2(ABC):
|
|
2003
2000
|
n.set_style(nstyle)
|
2004
2001
|
if n.name in node_effs.index:
|
2005
2002
|
e = node_effs.loc[n.name, "Final Parameter"]
|
2006
|
-
n.
|
2003
|
+
n.add_prop("node_effect", e)
|
2007
2004
|
else:
|
2008
|
-
n.
|
2005
|
+
n.add_prop("node_effect", 0)
|
2009
2006
|
if n.name in leaf_effs.index:
|
2010
2007
|
e = leaf_effs.loc[n.name, "Effect"]
|
2011
|
-
n.
|
2008
|
+
n.add_prop("leaf_effect", e)
|
2012
2009
|
else:
|
2013
|
-
n.
|
2010
|
+
n.add_prop("leaf_effect", 0)
|
2014
2011
|
|
2015
2012
|
# Scale effect values to get nice node sizes
|
2016
|
-
eff_max = np.max([np.abs(n.node_effect) for n in tree2.traverse()])
|
2017
|
-
leaf_eff_max = np.max([np.abs(n.leaf_effect) for n in tree2.traverse()])
|
2013
|
+
eff_max = np.max([np.abs(n.props.get("node_effect")) for n in tree2.traverse()])
|
2014
|
+
leaf_eff_max = np.max([np.abs(n.props.get("leaf_effect")) for n in tree2.traverse()])
|
2018
2015
|
|
2019
2016
|
def my_layout(node):
|
2020
2017
|
text_face = TextFace(node.name, tight_text=tight_text)
|
@@ -2022,10 +2019,10 @@ class CompositionalModel2(ABC):
|
|
2022
2019
|
faces.add_face_to_node(text_face, node, column=0, aligned=True)
|
2023
2020
|
|
2024
2021
|
# if node.is_leaf():
|
2025
|
-
size = (np.abs(node.node_effect) * 10 / eff_max) if node.node_effect != 0 else 0
|
2026
|
-
if np.sign(node.node_effect) == 1:
|
2022
|
+
size = (np.abs(node.props.get("node_effect")) * 10 / eff_max) if node.props.get("node_effect") != 0 else 0
|
2023
|
+
if np.sign(node.props.get("node_effect")) == 1:
|
2027
2024
|
color = "blue"
|
2028
|
-
elif np.sign(node.node_effect) == -1:
|
2025
|
+
elif np.sign(node.props.get("node_effect")) == -1:
|
2029
2026
|
color = "red"
|
2030
2027
|
else:
|
2031
2028
|
color = "cyan"
|
@@ -2061,13 +2058,13 @@ class CompositionalModel2(ABC):
|
|
2061
2058
|
tree_style.legend.add_face(TextFace(f" {eff_max * i / 4:.2f}"), column=1)
|
2062
2059
|
|
2063
2060
|
if show_leaf_effects:
|
2064
|
-
leaf_name = [node.name for node in tree2.traverse("postorder") if node.is_leaf
|
2061
|
+
leaf_name = [node.name for node in tree2.traverse("postorder") if node.is_leaf]
|
2065
2062
|
leaf_effs = leaf_effs.loc[leaf_name].reset_index()
|
2066
2063
|
palette = ["blue" if Effect > 0 else "red" for Effect in leaf_effs["Effect"].tolist()]
|
2067
2064
|
|
2068
2065
|
dir_path = Path.cwd()
|
2069
2066
|
dir_path = Path(dir_path / "tree_effect.png")
|
2070
|
-
tree2.render(dir_path, tree_style=tree_style, units="in")
|
2067
|
+
tree2.render(dir_path.as_posix(), tree_style=tree_style, units="in")
|
2071
2068
|
_, ax = plt.subplots(1, 2, figsize=(10, 10))
|
2072
2069
|
sns.barplot(data=leaf_effs, x="Effect", y="Cell Type", palette=palette, ax=ax[1])
|
2073
2070
|
img = mpimg.imread(dir_path)
|
@@ -2098,7 +2095,7 @@ class CompositionalModel2(ABC):
|
|
2098
2095
|
return None
|
2099
2096
|
|
2100
2097
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
2101
|
-
def plot_effects_umap( # pragma: no cover
|
2098
|
+
def plot_effects_umap( # pragma: no cover # noqa: D417
|
2102
2099
|
self,
|
2103
2100
|
mdata: MuData,
|
2104
2101
|
effect_name: str | list | None,
|
@@ -2211,7 +2208,7 @@ class CompositionalModel2(ABC):
|
|
2211
2208
|
def get_a(
|
2212
2209
|
tree: tt.core.ToyTree,
|
2213
2210
|
) -> tuple[np.ndarray, int]:
|
2214
|
-
"""Calculate ancestor matrix from a toytree tree
|
2211
|
+
"""Calculate ancestor matrix from a toytree tree.
|
2215
2212
|
|
2216
2213
|
Args:
|
2217
2214
|
tree: A toytree tree object.
|
@@ -2264,16 +2261,14 @@ def collapse_singularities(tree: tt.core.ToyTree) -> tt.core.ToyTree:
|
|
2264
2261
|
A_T = A.T
|
2265
2262
|
unq, count = np.unique(A_T, axis=0, return_counts=True)
|
2266
2263
|
|
2267
|
-
repeated_idx = []
|
2268
|
-
for repeated_group in unq[count > 1]:
|
2269
|
-
repeated_idx.append(np.argwhere(np.all(A_T == repeated_group, axis=1)).ravel())
|
2264
|
+
repeated_idx = [np.argwhere(np.all(repeated_group == A_T, axis=1)).ravel() for repeated_group in unq[count > 1]]
|
2270
2265
|
|
2271
2266
|
nodes_to_delete = [i for idx in repeated_idx for i in idx[1:]]
|
2272
2267
|
|
2273
2268
|
# _coords.update() scrambles the idx of leaves. Therefore, keep track of it here
|
2274
2269
|
tree_new = tree.copy()
|
2275
2270
|
for node in tree_new.treenode.traverse():
|
2276
|
-
node.
|
2271
|
+
node.add_prop("idx_orig", node.idx)
|
2277
2272
|
|
2278
2273
|
for n in nodes_to_delete:
|
2279
2274
|
node = tree_new.idx_dict[n]
|
@@ -2289,21 +2284,16 @@ def collapse_singularities(tree: tt.core.ToyTree) -> tt.core.ToyTree:
|
|
2289
2284
|
return tree_new
|
2290
2285
|
|
2291
2286
|
|
2292
|
-
def traverse(df_, a, i, innerl):
|
2293
|
-
"""
|
2294
|
-
|
2295
|
-
Adapted from https://stackoverflow.com/questions/15343338/how-to-convert-a-data-frame-to-tree-structure-object-such-as-dendrogram
|
2287
|
+
def traverse(df_: pd.DataFrame, a: str, i: int, innerl: bool) -> str:
|
2288
|
+
"""Helper function for df2newick.
|
2289
|
+
|
2290
|
+
Adapted from https://stackoverflow.com/questions/15343338/how-to-convert-a-data-frame-to-tree-structure-object-such-as-dendrogram.
|
2296
2291
|
"""
|
2297
2292
|
if i + 1 < df_.shape[1]:
|
2298
2293
|
a_inner = pd.unique(df_.loc[np.where(df_.iloc[:, i] == a)].iloc[:, i + 1])
|
2299
2294
|
|
2300
|
-
desc = []
|
2301
|
-
|
2302
|
-
desc.append(traverse(df_, b, i + 1, innerl))
|
2303
|
-
if innerl:
|
2304
|
-
il = a
|
2305
|
-
else:
|
2306
|
-
il = ""
|
2295
|
+
desc = [traverse(df_, b, i + 1, innerl) for b in a_inner]
|
2296
|
+
il = a if innerl else ""
|
2307
2297
|
out = f"({','.join(desc)}){il}"
|
2308
2298
|
else:
|
2309
2299
|
out = a
|
@@ -2327,9 +2317,7 @@ def df2newick(df: pd.DataFrame, levels: list[str], inner_label: bool = True) ->
|
|
2327
2317
|
df_tax = df.loc[:, [x for x in levels if x in df.columns]]
|
2328
2318
|
|
2329
2319
|
alevel = pd.unique(df_tax.iloc[:, 0])
|
2330
|
-
strs = []
|
2331
|
-
for a in alevel:
|
2332
|
-
strs.append(traverse(df_tax, a, 0, inner_label))
|
2320
|
+
strs = [traverse(df_tax, a, 0, inner_label) for a in alevel]
|
2333
2321
|
|
2334
2322
|
newick = f"({','.join(strs)});"
|
2335
2323
|
return newick
|
@@ -2340,10 +2328,10 @@ def get_a_2(
|
|
2340
2328
|
leaf_order: list[str] = None,
|
2341
2329
|
node_order: list[str] = None,
|
2342
2330
|
) -> tuple[np.ndarray, int]:
|
2343
|
-
"""Calculate ancestor matrix from a
|
2331
|
+
"""Calculate ancestor matrix from a ete4 tree.
|
2344
2332
|
|
2345
2333
|
Args:
|
2346
|
-
tree: A
|
2334
|
+
tree: A ete4 tree object.
|
2347
2335
|
leaf_order: List of leaf names how they should appear as the rows of the ancestor matrix.
|
2348
2336
|
If None, the ordering will be as in `tree.iter_leaves()`
|
2349
2337
|
node_order: List of node names how they should appear as the columns of the ancestor matrix
|
@@ -2358,29 +2346,29 @@ def get_a_2(
|
|
2358
2346
|
number of nodes in the tree, excluding the root node
|
2359
2347
|
"""
|
2360
2348
|
try:
|
2361
|
-
import
|
2349
|
+
import ete4 as ete
|
2362
2350
|
except ImportError:
|
2363
2351
|
raise ImportError(
|
2364
2352
|
"To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
|
2365
2353
|
) from None
|
2366
2354
|
|
2367
|
-
n_tips = len(tree.
|
2368
|
-
n_nodes = len(tree.
|
2355
|
+
n_tips = len(list(tree.leaves()))
|
2356
|
+
n_nodes = len(list(tree.descendants()))
|
2369
2357
|
|
2370
|
-
node_names = [n.name for n in tree.
|
2358
|
+
node_names = [n.name for n in tree.descendants()]
|
2371
2359
|
duplicates = [x for x in node_names if node_names.count(x) > 1]
|
2372
2360
|
if len(duplicates) > 0:
|
2373
2361
|
raise ValueError(f"Tree nodes have duplicate names: {duplicates}. Make sure that node names are unique!")
|
2374
2362
|
|
2375
2363
|
# Initialize ancestor matrix
|
2376
2364
|
A_ = pd.DataFrame(np.zeros((n_tips, n_nodes)))
|
2377
|
-
A_.index = tree.
|
2378
|
-
A_.columns = [n.name for n in tree.
|
2365
|
+
A_.index = tree.leaf_names()
|
2366
|
+
A_.columns = [n.name for n in tree.descendants()]
|
2379
2367
|
|
2380
2368
|
# Fill in 1's for all connections
|
2381
|
-
for node in tree.
|
2382
|
-
for leaf in tree.
|
2383
|
-
if leaf in node.
|
2369
|
+
for node in tree.descendants():
|
2370
|
+
for leaf in tree.leaves():
|
2371
|
+
if leaf in node.leaves():
|
2384
2372
|
A_.loc[leaf.name, node.name] = 1
|
2385
2373
|
|
2386
2374
|
# Order rows and columns
|
@@ -2394,15 +2382,15 @@ def get_a_2(
|
|
2394
2382
|
|
2395
2383
|
|
2396
2384
|
def collapse_singularities_2(tree: Tree) -> Tree:
|
2397
|
-
"""Collapses (deletes) nodes in a
|
2385
|
+
"""Collapses (deletes) nodes in a ete4 tree that are singularities (have only one child).
|
2398
2386
|
|
2399
2387
|
Args:
|
2400
|
-
tree: A
|
2388
|
+
tree: A ete4 tree object
|
2401
2389
|
|
2402
2390
|
Returns:
|
2403
|
-
A
|
2391
|
+
A ete4 tree without singularities.
|
2404
2392
|
"""
|
2405
|
-
for node in tree.
|
2393
|
+
for node in tree.descendants():
|
2406
2394
|
if len(node.get_children()) == 1:
|
2407
2395
|
node.delete()
|
2408
2396
|
|
@@ -2427,13 +2415,10 @@ def linkage_to_newick(
|
|
2427
2415
|
tree = sp_hierarchy.to_tree(Z, False)
|
2428
2416
|
|
2429
2417
|
def build_newick(node, newick, parentdist, leaf_names):
|
2430
|
-
if node.is_leaf
|
2418
|
+
if node.is_leaf:
|
2431
2419
|
return f"{leaf_names[node.id]}:{(parentdist - node.dist) / 2}{newick}"
|
2432
2420
|
else:
|
2433
|
-
if len(newick) > 0
|
2434
|
-
newick = f"):{(parentdist - node.dist) / 2}{newick}"
|
2435
|
-
else:
|
2436
|
-
newick = ");"
|
2421
|
+
newick = f"):{(parentdist - node.dist) / 2}{newick}" if len(newick) > 0 else ");"
|
2437
2422
|
newick = build_newick(node.get_left(), newick, node.dist, leaf_names)
|
2438
2423
|
newick = build_newick(node.get_right(), f",{newick}", node.dist, leaf_names)
|
2439
2424
|
newick = f"({newick}"
|
@@ -2478,10 +2463,10 @@ def import_tree(
|
|
2478
2463
|
|
2479
2464
|
See `key_added` parameter description for the storage path of tree.
|
2480
2465
|
|
2481
|
-
tree: A
|
2466
|
+
tree: A ete4 tree object.
|
2482
2467
|
"""
|
2483
2468
|
try:
|
2484
|
-
import
|
2469
|
+
import ete4 as ete
|
2485
2470
|
except ImportError:
|
2486
2471
|
raise ImportError(
|
2487
2472
|
"To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
|
@@ -2506,32 +2491,33 @@ def import_tree(
|
|
2506
2491
|
data_1.uns["dendrogram_cell_label"]["linkage"],
|
2507
2492
|
labels=data_1.uns["dendrogram_cell_label"]["categories_ordered"],
|
2508
2493
|
)
|
2509
|
-
tree = ete.Tree(newick,
|
2494
|
+
tree = ete.Tree(newick, parser=1)
|
2510
2495
|
node_id = 0
|
2511
|
-
for n in tree.
|
2512
|
-
if not n.is_leaf
|
2496
|
+
for n in tree.descendants():
|
2497
|
+
if not n.is_leaf:
|
2513
2498
|
n.name = str(node_id)
|
2514
2499
|
node_id += 1
|
2515
2500
|
elif levels_orig is not None:
|
2516
2501
|
newick = df2newick(data_1.obs.reset_index(), levels=levels_orig)
|
2517
|
-
tree = ete.Tree(newick,
|
2502
|
+
tree = ete.Tree(newick, parser=8)
|
2503
|
+
|
2518
2504
|
if add_level_name:
|
2519
|
-
for n in tree.
|
2520
|
-
if not n.is_leaf
|
2521
|
-
dist = n.get_distance(n, tree)
|
2505
|
+
for n in tree.descendants():
|
2506
|
+
if not n.is_leaf:
|
2507
|
+
dist = n.get_distance(n, tree, topological=True)
|
2522
2508
|
n.name = f"{levels_orig[int(dist) - 1]}_{n.name}"
|
2523
2509
|
elif levels_agg is not None:
|
2524
2510
|
newick = df2newick(data_2.var.reset_index(), levels=levels_agg)
|
2525
|
-
tree = ete.Tree(newick,
|
2511
|
+
tree = ete.Tree(newick, parser=8)
|
2526
2512
|
if add_level_name:
|
2527
|
-
for n in tree.
|
2528
|
-
if not n.is_leaf
|
2529
|
-
dist = n.get_distance(n, tree)
|
2513
|
+
for n in tree.descendants():
|
2514
|
+
if not n.is_leaf:
|
2515
|
+
dist = n.get_distance(n, tree, topological=True)
|
2530
2516
|
n.name = f"{levels_agg[int(dist) - 1]}_{n.name}"
|
2531
2517
|
else:
|
2532
2518
|
raise ValueError("Either dendrogram_key, levels_orig or levels_agg must be specified!")
|
2533
2519
|
|
2534
|
-
node_names = [n.name for n in tree.
|
2520
|
+
node_names = [n.name for n in tree.descendants()]
|
2535
2521
|
duplicates = {x for x in node_names if node_names.count(x) > 1}
|
2536
2522
|
if len(duplicates) > 0:
|
2537
2523
|
raise ValueError(f"Tree nodes have duplicate names: {duplicates}. Make sure that node names are unique!")
|