pertpy 0.10.0__py3-none-any.whl → 0.11.1__py3-none-any.whl

This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (44)
  1. pertpy/__init__.py +5 -1
  2. pertpy/_doc.py +1 -3
  3. pertpy/_types.py +6 -0
  4. pertpy/data/_dataloader.py +68 -24
  5. pertpy/data/_datasets.py +9 -9
  6. pertpy/metadata/__init__.py +2 -1
  7. pertpy/metadata/_cell_line.py +133 -25
  8. pertpy/metadata/_look_up.py +13 -19
  9. pertpy/metadata/_moa.py +1 -1
  10. pertpy/preprocessing/_guide_rna.py +138 -44
  11. pertpy/preprocessing/_guide_rna_mixture.py +17 -19
  12. pertpy/tools/__init__.py +4 -3
  13. pertpy/tools/_augur.py +106 -98
  14. pertpy/tools/_cinemaot.py +74 -114
  15. pertpy/tools/_coda/_base_coda.py +134 -148
  16. pertpy/tools/_coda/_sccoda.py +69 -70
  17. pertpy/tools/_coda/_tasccoda.py +74 -80
  18. pertpy/tools/_dialogue.py +48 -41
  19. pertpy/tools/_differential_gene_expression/_base.py +21 -31
  20. pertpy/tools/_differential_gene_expression/_checks.py +4 -6
  21. pertpy/tools/_differential_gene_expression/_dge_comparison.py +5 -6
  22. pertpy/tools/_differential_gene_expression/_edger.py +6 -10
  23. pertpy/tools/_differential_gene_expression/_pydeseq2.py +1 -1
  24. pertpy/tools/_differential_gene_expression/_simple_tests.py +3 -3
  25. pertpy/tools/_differential_gene_expression/_statsmodels.py +8 -5
  26. pertpy/tools/_distances/_distance_tests.py +1 -2
  27. pertpy/tools/_distances/_distances.py +31 -46
  28. pertpy/tools/_enrichment.py +7 -22
  29. pertpy/tools/_milo.py +19 -15
  30. pertpy/tools/_mixscape.py +73 -75
  31. pertpy/tools/_perturbation_space/_clustering.py +4 -4
  32. pertpy/tools/_perturbation_space/_comparison.py +4 -4
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +83 -32
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +10 -10
  35. pertpy/tools/_perturbation_space/_simple.py +12 -14
  36. pertpy/tools/_scgen/_scgen.py +16 -17
  37. pertpy/tools/_scgen/_scgenvae.py +2 -2
  38. pertpy/tools/_scgen/_utils.py +3 -1
  39. {pertpy-0.10.0.dist-info → pertpy-0.11.1.dist-info}/METADATA +42 -24
  40. pertpy-0.11.1.dist-info/RECORD +58 -0
  41. {pertpy-0.10.0.dist-info → pertpy-0.11.1.dist-info}/licenses/LICENSE +1 -0
  42. pertpy/tools/_kernel_pca.py +0 -50
  43. pertpy-0.10.0.dist-info/RECORD +0 -58
  44. {pertpy-0.10.0.dist-info → pertpy-0.11.1.dist-info}/WHEEL +0 -0
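The hunks below are from pertpy/tools/_coda/_base_coda.py (entry 15 above). A recurring change is moving heavy optional dependencies (`arviz`, `patsy`) from module-level imports into the functions that use them. A minimal, untested sketch of that deferred-import pattern follows; the helper name, formula, and toy data are illustrative and not part of pertpy's API.

```python
import pandas as pd


def build_covariate_matrix(obs: pd.DataFrame, formula: str) -> pd.DataFrame:
    # Import patsy lazily, as the diff now does inside the model setup,
    # so the dependency is only needed when this code path actually runs.
    import patsy

    design = patsy.dmatrix(formula, obs)
    names = design.design_info.column_names[1:]  # drop the intercept column
    return pd.DataFrame(design[:, 1:], columns=names, index=obs.index)


# Illustrative usage with a toy obs table.
obs = pd.DataFrame({"condition": ["ctrl", "treated", "treated", "ctrl"]})
print(build_covariate_matrix(obs, "condition"))
```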
@@ -4,12 +4,10 @@ from abc import ABC, abstractmethod
  from pathlib import Path
  from typing import TYPE_CHECKING, Literal
 
- import arviz as az
  import jax.numpy as jnp
  import matplotlib.pyplot as plt
  import numpy as np
  import pandas as pd
- import patsy as pt
  import scanpy as sc
  import seaborn as sns
  from adjustText import adjust_text
@@ -33,7 +31,7 @@ if TYPE_CHECKING:
 
  import numpyro as npy
  import toytree as tt
- from ete3 import Tree
+ from ete4 import Tree
  from jax._src.typing import Array
  from matplotlib.axes import Axes
  from matplotlib.colors import Colormap
@@ -126,7 +124,9 @@ class CompositionalModel2(ABC):
  sample_adata.X = sample_adata.X.astype(dtype)
 
  # Build covariate matrix from R-like formula, save in obsm
- covariate_matrix = pt.dmatrix(formula, sample_adata.obs)
+ import patsy
+
+ covariate_matrix = patsy.dmatrix(formula, sample_adata.obs)
  covariate_names = covariate_matrix.design_info.column_names[1:]
  sample_adata.obsm["covariate_matrix"] = np.array(covariate_matrix[:, 1:]).astype(dtype)
 
@@ -198,7 +198,7 @@ class CompositionalModel2(ABC):
  *args,
  **kwargs,
  ):
- """Background function that executes any numpyro MCMC algorithm and processes its results
+ """Background function that executes any numpyro MCMC algorithm and processes its results.
 
  Args:
  sample_adata: anndata object with cell counts as sample_adata.X and covariates saved in sample_adata.obs.
@@ -294,6 +294,8 @@ class CompositionalModel2(ABC):
  num_warmup: Number of burn-in (warmup) samples.
  rng_key: The rng state used.
  copy: Return a copy instead of writing to adata.
+ *args: Additional args passed to numpyro NUTS
+ **kwargs: Additional kwargs passed to numpyro NUTS
 
  Returns:
  Calls `self.__run_mcmc`
@@ -347,6 +349,8 @@ class CompositionalModel2(ABC):
  num_warmup: Number of burn-in (warmup) samples.
  rng_key: The rng state used. If None, a random state will be selected.
  copy: Return a copy instead of writing to adata.
+ *args: Additional args passed to numpyro HMC
+ **kwargs: Additional kwargs passed to numpyro HMC
 
  Examples:
  >>> import pertpy as pt
@@ -396,7 +400,8 @@ class CompositionalModel2(ABC):
  self, sample_adata: AnnData, est_fdr: float = 0.05, *args, **kwargs
  ) -> tuple[pd.DataFrame, pd.DataFrame] | tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
  """Generates summary dataframes for intercepts, effects and node-level effect (if using tree aggregation).
- This function builds on and supports all functionalities from ``az.summary``.
+
+ This function builds on and supports all functionalities from ``az.summary``.
 
  Args:
  sample_adata: Anndata object with cell counts as sample_adata.X and covariates saved in sample_adata.obs.
@@ -405,7 +410,7 @@ class CompositionalModel2(ABC):
  kwargs: Passed to ``az.summary``
 
  Returns:
- Tuple[pd.DataFrame, pd.DataFrame] or Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: Intercept, effect and node-level DataFrames
+ Tuple[:class:pandas.DataFrame, :class:pandas.DataFrame] or Tuple[:class:pandas.DataFrame, :class:pandas.DataFrame, :class:pandas.DataFrame]: Intercept, effect and node-level DataFrames
 
  intercept_df
  Summary of intercept parameters. Contains one row per cell type.
@@ -435,7 +440,7 @@ class CompositionalModel2(ABC):
  - Delta: Decision boundary value - threshold of practical significance
  - Is credible: Boolean indicator whether effect is credible
 
- Examples:
+ Examples:
  >>> import pertpy as pt
  >>> haber_cells = pt.dt.haber_2017_regions()
  >>> sccoda = pt.tl.Sccoda()
@@ -456,6 +461,8 @@ class CompositionalModel2(ABC):
  else:
  raise ValueError("No valid model type!")
 
+ import arviz as az
+
  summ = az.summary(
  data=self.make_arviz(sample_adata, num_prior_samples=0, use_posterior_predictive=False),
  var_names=var_names,
@@ -684,7 +691,7 @@ class CompositionalModel2(ABC):
 
  if fdr < alpha:
  # ceiling with 3 decimals precision
- c = np.floor(c * 10**3) / 10**3
+ c = np.floor(c * 10**3) / 10**3 # noqa: PLW2901
  return c, fdr
  return 1.0, 0
 
@@ -737,7 +744,8 @@ class CompositionalModel2(ABC):
  node_df: pd.DataFrame,
  ) -> pd.DataFrame:
  """Evaluation of MCMC results for node-level effect parameters. This function is only used within self.summary_prepare.
- This function determines whether node-level effects are credible or not
+
+ This function determines whether node-level effects are credible or not.
 
  Args:
  sample_adata: Anndata object with cell counts as sample_adata.X and covariates saved in sample_adata.obs.
@@ -932,15 +940,15 @@ class CompositionalModel2(ABC):
  )
  console.print(table)
 
- def get_intercept_df(self, data: AnnData | MuData, modality_key: str = "coda"):
- """Get intercept dataframe as printed in the extended summary
+ def get_intercept_df(self, data: AnnData | MuData, modality_key: str = "coda") -> pd.DataFrame:
+ """Get intercept dataframe as printed in the extended summary.
 
  Args:
  data: AnnData object or MuData object.
  modality_key: If data is a MuData object, specify which modality to use.
 
  Returns:
- pd.DataFrame: Intercept data frame.
+ Intercept data frame.
 
  Examples:
  >>> import pertpy as pt
@@ -963,15 +971,15 @@ class CompositionalModel2(ABC):
 
  return sample_adata.varm["intercept_df"]
 
- def get_effect_df(self, data: AnnData | MuData, modality_key: str = "coda"):
- """Get effect dataframe as printed in the extended summary
+ def get_effect_df(self, data: AnnData | MuData, modality_key: str = "coda") -> pd.DataFrame:
+ """Get effect dataframe as printed in the extended summary.
 
  Args:
  data: AnnData object or MuData object.
  modality_key: If data is a MuData object, specify which modality to use.
 
  Returns:
- pd.DataFrame: Effect data frame.
+ Effect data frame.
 
  Examples:
  >>> import pertpy as pt
@@ -1005,15 +1013,15 @@ class CompositionalModel2(ABC):
 
  return effect_df
 
- def get_node_df(self, data: AnnData | MuData, modality_key: str = "coda"):
- """Get node effect dataframe as printed in the extended summary of a tascCODA model
+ def get_node_df(self, data: AnnData | MuData, modality_key: str = "coda") -> pd.DataFrame:
+ """Get node effect dataframe as printed in the extended summary of a tascCODA model.
 
  Args:
  data: AnnData object or MuData object.
  modality_key: If data is a MuData object, specify which modality to use.
 
  Returns:
- pd.DataFrame: Node effect data frame.
+ Node effect data frame.
 
  Examples:
  >>> import pertpy as pt
@@ -1030,7 +1038,6 @@ class CompositionalModel2(ABC):
  >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
  >>> node_effects = tasccoda.get_node_df(mdata)
  """
-
  if isinstance(data, MuData):
  try:
  sample_adata = data[modality_key]
@@ -1043,8 +1050,9 @@ class CompositionalModel2(ABC):
  return sample_adata.uns["scCODA_params"]["node_df"]
 
  def set_fdr(self, data: AnnData | MuData, est_fdr: float, modality_key: str = "coda", *args, **kwargs):
- """Direct posterior probability approach to calculate credible effects while keeping the expected FDR at a certain level
- Note: Does not work for spike-and-slab LASSO selection method
+ """Direct posterior probability approach to calculate credible effects while keeping the expected FDR at a certain level.
+
+ Note: Does not work for spike-and-slab LASSO selection method.
 
  Args:
  data: AnnData object or MuData object.
@@ -1079,7 +1087,8 @@ class CompositionalModel2(ABC):
 
  def credible_effects(self, data: AnnData | MuData, modality_key: str = "coda", est_fdr: float = None) -> pd.Series:
  """Decides which effects of the scCODA model are credible based on an adjustable inclusion probability threshold.
- Note: Parameter est_fdr has no effect for spike-and-slab LASSO selection method
+
+ Note: Parameter est_fdr has no effect for spike-and-slab LASSO selection method.
 
  Args:
  data: AnnData object or MuData object.
@@ -1087,7 +1096,7 @@ class CompositionalModel2(ABC):
  est_fdr: Estimated false discovery rate. Must be between 0 and 1.
 
  Returns:
- pd.Series: Credible effect decision series which includes boolean values indicate whether effects are credible under inc_prob_threshold.
+ Credible effect decision series which includes boolean values indicate whether effects are credible under inc_prob_threshold.
  """
  if isinstance(data, MuData):
  try:
@@ -1109,16 +1118,15 @@ class CompositionalModel2(ABC):
  else:
  _, eff_df = self.summary_prepare(sample_adata, est_fdr=est_fdr) # type: ignore
  # otherwise, get pre-calculated DataFrames. Effect DataFrame is stitched together from varm
+ elif model_type == "tree_agg" and select_type == "sslasso":
+ eff_df = sample_adata.uns["scCODA_params"]["node_df"]
  else:
- if model_type == "tree_agg" and select_type == "sslasso":
- eff_df = sample_adata.uns["scCODA_params"]["node_df"]
- else:
- covariates = sample_adata.uns["scCODA_params"]["covariate_names"]
- effect_dfs = [sample_adata.varm[f"effect_df_{cov}"] for cov in covariates]
- eff_df = pd.concat(effect_dfs)
- eff_df.index = pd.MultiIndex.from_product(
- (covariates, sample_adata.var.index.tolist()), names=["Covariate", "Cell Type"]
- )
+ covariates = sample_adata.uns["scCODA_params"]["covariate_names"]
+ effect_dfs = [sample_adata.varm[f"effect_df_{cov}"] for cov in covariates]
+ eff_df = pd.concat(effect_dfs)
+ eff_df.index = pd.MultiIndex.from_product(
+ (covariates, sample_adata.var.index.tolist()), names=["Covariate", "Cell Type"]
+ )
 
  out = eff_df["Final Parameter"] != 0
  out.rename("credible change")
@@ -1188,7 +1196,7 @@ class CompositionalModel2(ABC):
  return ax
 
  @_doc_params(common_plot_args=doc_common_plot_args)
- def plot_stacked_barplot( # pragma: no cover
+ def plot_stacked_barplot( # pragma: no cover # noqa: D417
  self,
  data: AnnData | MuData,
  feature_name: str,
@@ -1215,7 +1223,7 @@ class CompositionalModel2(ABC):
  {common_plot_args}
 
  Returns:
- If `return_fig` is `True`, returns the figure, otherwise `None`.
+ If `return_fig` is `True`, returns the Figure, otherwise `None`.
 
  Examples:
  >>> import pertpy as pt
@@ -1230,8 +1238,6 @@ class CompositionalModel2(ABC):
  """
  if isinstance(data, MuData):
  data = data[modality_key]
- if isinstance(data, AnnData):
- data = data
 
  ct_names = data.var.index
 
@@ -1283,7 +1289,7 @@ class CompositionalModel2(ABC):
  return None
 
  @_doc_params(common_plot_args=doc_common_plot_args)
- def plot_effects_barplot( # pragma: no cover
+ def plot_effects_barplot( # pragma: no cover # noqa: D417
  self,
  data: AnnData | MuData,
  *,
@@ -1340,8 +1346,7 @@ class CompositionalModel2(ABC):
  args_barplot = {}
  if isinstance(data, MuData):
  data = data[modality_key]
- if isinstance(data, AnnData):
- data = data
+
  # Get covariate names from adata, partition into those with nonzero effects for min. one cell type/no cell types
  covariate_names = data.uns["scCODA_params"]["covariate_names"]
  if covariates is not None:
@@ -1372,18 +1377,16 @@ class CompositionalModel2(ABC):
 
  plot_df = plot_df.reset_index()
 
- if len(covariate_names_zero) != 0:
- if plot_facets:
- if plot_zero_covariate and not plot_zero_cell_type:
- for covariate_name_zero in covariate_names_zero:
- new_row = {
- "Covariate": covariate_name_zero,
- "Cell Type": "zero",
- "value": 0,
- }
- plot_df = pd.concat([plot_df, pd.DataFrame([new_row])], ignore_index=True)
- plot_df["covariate_"] = pd.Categorical(plot_df["Covariate"], covariate_names)
- plot_df = plot_df.sort_values(["covariate_"])
+ if len(covariate_names_zero) != 0 and plot_facets and plot_zero_covariate and not plot_zero_cell_type:
+ for covariate_name_zero in covariate_names_zero:
+ new_row = {
+ "Covariate": covariate_name_zero,
+ "Cell Type": "zero",
+ "value": 0,
+ }
+ plot_df = pd.concat([plot_df, pd.DataFrame([new_row])], ignore_index=True)
+ plot_df["covariate_"] = pd.Categorical(plot_df["Covariate"], covariate_names)
+ plot_df = plot_df.sort_values(["covariate_"])
  if not plot_zero_cell_type:
  cell_type_names_zero = [
  name
@@ -1427,9 +1430,8 @@ class CompositionalModel2(ABC):
  ax.set_title(covariate_names[i])
  if len(ax.get_xticklabels()) < 5:
  ax.set_aspect(10 / len(ax.get_xticklabels()))
- if len(ax.get_xticklabels()) == 1:
- if ax.get_xticklabels()[0]._text == "zero":
- ax.set_xticks([])
+ if len(ax.get_xticklabels()) == 1 and ax.get_xticklabels()[0]._text == "zero":
+ ax.set_xticks([])
 
  # If not plot as facets, call barplot to plot cell types on the x-axis.
  else:
@@ -1460,6 +1462,7 @@ class CompositionalModel2(ABC):
  ax=ax,
  )
  cell_types = pd.unique(plot_df["Cell Type"])
+ ax.set_xticks(cell_types)
  ax.set_xticklabels(cell_types, rotation=90)
 
  if return_fig and plot_facets:
@@ -1470,7 +1473,7 @@ class CompositionalModel2(ABC):
  return None
 
  @_doc_params(common_plot_args=doc_common_plot_args)
- def plot_boxplots( # pragma: no cover
+ def plot_boxplots( # pragma: no cover # noqa: D417
  self,
  data: AnnData | MuData,
  feature_name: str,
@@ -1532,8 +1535,7 @@ class CompositionalModel2(ABC):
  args_swarmplot = {}
  if isinstance(data, MuData):
  data = data[modality_key]
- if isinstance(data, AnnData):
- data = data
+
  # y scale transformations
  if y_scale == "relative":
  sample_sums = np.sum(data.X, axis=1, keepdims=True)
@@ -1607,10 +1609,7 @@ class CompositionalModel2(ABC):
  )
 
  if add_dots:
- if "hue" in args_swarmplot:
- hue = args_swarmplot.pop("hue")
- else:
- hue = None
+ hue = args_swarmplot.pop("hue") if "hue" in args_swarmplot else None
 
  if hue is None:
  g.map(
@@ -1675,6 +1674,7 @@ class CompositionalModel2(ABC):
  )
 
  cell_types = pd.unique(plot_df["Cell type"])
+ ax.set_xticks(cell_types)
  ax.set_xticklabels(cell_types, rotation=90)
 
  if show_legend:
@@ -1702,7 +1702,7 @@ class CompositionalModel2(ABC):
  return None
 
  @_doc_params(common_plot_args=doc_common_plot_args)
- def plot_rel_abundance_dispersion_plot( # pragma: no cover
+ def plot_rel_abundance_dispersion_plot( # pragma: no cover # noqa: D417
  self,
  data: AnnData | MuData,
  *,
@@ -1750,8 +1750,7 @@ class CompositionalModel2(ABC):
  """
  if isinstance(data, MuData):
  data = data[modality_key]
- if isinstance(data, AnnData):
- data = data
+
  if ax is None:
  _, ax = plt.subplots(figsize=figsize, dpi=dpi)
 
@@ -1823,13 +1822,13 @@ class CompositionalModel2(ABC):
  return None
 
  @_doc_params(common_plot_args=doc_common_plot_args)
- def plot_draw_tree( # pragma: no cover
+ def plot_draw_tree( # pragma: no cover # noqa: D417
  self,
  data: AnnData | MuData,
  *,
  modality_key: str = "coda",
- tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors
- tight_text: bool | None = False,
+ tree: str = "tree", # Also type ete4.Tree. Omitted due to import errors
+ tight_text: bool = False,
  show_scale: bool | None = False,
  units: Literal["px", "mm", "in"] | None = "px",
  figsize: tuple[float, float] | None = (None, None),
@@ -1837,12 +1836,12 @@ class CompositionalModel2(ABC):
  save: str | bool = False,
  return_fig: bool = False,
  ) -> Tree | None:
- """Plot a tree using input ete3 tree object.
+ """Plot a tree using input ete4 tree object.
 
  Args:
  data: AnnData object or MuData object.
  modality_key: If data is a MuData object, specify which modality to use.
- tree: A ete3 tree object or a str to indicate the tree stored in `.uns`.
+ tree: A ete4 tree object or a str to indicate the tree stored in `.uns`.
  tight_text: When False, boundaries of the text are approximated according to general font metrics,
  producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
  show_scale: Include the scale legend in the tree image or not.
@@ -1853,7 +1852,7 @@ class CompositionalModel2(ABC):
  {common_plot_args}
 
  Returns:
- Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`) or plot the tree inline (`show = False`)
+ Depending on `save`, returns :class:`ete4.core.tree.Tree` and :class:`ete4.treeview.TreeStyle` (`save = 'output.png'`) or plot the tree inline (`save = False`)
 
  Examples:
  >>> import pertpy as pt
@@ -1874,7 +1873,8 @@ class CompositionalModel2(ABC):
  .. image:: /_static/docstring_previews/tasccoda_draw_tree.png
  """
  try:
- from ete3 import CircleFace, NodeStyle, TextFace, Tree, TreeStyle, faces
+ from ete4 import Tree
+ from ete4.treeview import CircleFace, NodeStyle, TextFace, TreeStyle, faces
  except ImportError:
  raise ImportError(
  "To use tasccoda please install additional dependencies with `pip install pertpy[coda]`"
@@ -1882,8 +1882,6 @@ class CompositionalModel2(ABC):
 
  if isinstance(data, MuData):
  data = data[modality_key]
- if isinstance(data, AnnData):
- data = data
  if isinstance(tree, str):
  tree = data.uns[tree]
 
@@ -1896,7 +1894,7 @@ class CompositionalModel2(ABC):
  tree_style.layout_fn = my_layout
  tree_style.show_scale = show_scale
 
- if save is not None:
+ if save:
  tree.render(save, tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore
  if return_fig:
  return tree, tree_style
@@ -1904,13 +1902,13 @@ class CompositionalModel2(ABC):
  return None
 
  @_doc_params(common_plot_args=doc_common_plot_args)
- def plot_draw_effects( # pragma: no cover
+ def plot_draw_effects( # pragma: no cover # noqa: D417
  self,
  data: AnnData | MuData,
  covariate: str,
  *,
  modality_key: str = "coda",
- tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors
+ tree: str = "tree", # Also type ete4.Tree. Omitted due to import errors
  show_legend: bool | None = None,
  show_leaf_effects: bool | None = False,
  tight_text: bool | None = False,
@@ -1927,7 +1925,7 @@ class CompositionalModel2(ABC):
  data: AnnData object or MuData object.
  covariate: The covariate, whose effects should be plotted.
  modality_key: If data is a MuData object, specify which modality to use.
- tree: A ete3 tree object or a str to indicate the tree stored in `.uns`.
+ tree: A ete4 tree object or a str to indicate the tree stored in `.uns`.
  show_legend: If show legend of nodes significant effects or not.
  Defaults to False if show_leaf_effects is True.
  show_leaf_effects: If True, plot bar plots which indicate leave-level significant effects.
@@ -1941,8 +1939,8 @@ class CompositionalModel2(ABC):
  {common_plot_args}
 
  Returns:
- Returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`return_fig = False`)
- or plot the tree inline (`show = True`)
+ Depending on `save`, returns :class:`ete4.core.tree.Tree` and :class:`ete4.treeview.TreeStyle` (`save = 'output.png'`)
+ or plot the tree inline (`save = False`).
 
  Examples:
  >>> import pertpy as pt
@@ -1963,7 +1961,8 @@ class CompositionalModel2(ABC):
  .. image:: /_static/docstring_previews/tasccoda_draw_effects.png
  """
  try:
- from ete3 import CircleFace, NodeStyle, TextFace, Tree, TreeStyle, faces
+ from ete4 import Tree
+ from ete4.treeview import CircleFace, NodeStyle, TextFace, TreeStyle, faces
  except ImportError:
  raise ImportError(
  "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
@@ -1971,8 +1970,6 @@ class CompositionalModel2(ABC):
 
  if isinstance(data, MuData):
  data = data[modality_key]
- if isinstance(data, AnnData):
- data = data
  if show_legend is None:
  show_legend = not show_leaf_effects
  elif show_legend:
@@ -2003,18 +2000,18 @@ class CompositionalModel2(ABC):
  n.set_style(nstyle)
  if n.name in node_effs.index:
  e = node_effs.loc[n.name, "Final Parameter"]
- n.add_feature("node_effect", e)
+ n.add_prop("node_effect", e)
  else:
- n.add_feature("node_effect", 0)
+ n.add_prop("node_effect", 0)
  if n.name in leaf_effs.index:
  e = leaf_effs.loc[n.name, "Effect"]
- n.add_feature("leaf_effect", e)
+ n.add_prop("leaf_effect", e)
  else:
- n.add_feature("leaf_effect", 0)
+ n.add_prop("leaf_effect", 0)
 
  # Scale effect values to get nice node sizes
- eff_max = np.max([np.abs(n.node_effect) for n in tree2.traverse()])
- leaf_eff_max = np.max([np.abs(n.leaf_effect) for n in tree2.traverse()])
+ eff_max = np.max([np.abs(n.props.get("node_effect")) for n in tree2.traverse()])
+ leaf_eff_max = np.max([np.abs(n.props.get("leaf_effect")) for n in tree2.traverse()])
 
  def my_layout(node):
  text_face = TextFace(node.name, tight_text=tight_text)
@@ -2022,10 +2019,10 @@ class CompositionalModel2(ABC):
  faces.add_face_to_node(text_face, node, column=0, aligned=True)
 
  # if node.is_leaf():
- size = (np.abs(node.node_effect) * 10 / eff_max) if node.node_effect != 0 else 0
- if np.sign(node.node_effect) == 1:
+ size = (np.abs(node.props.get("node_effect")) * 10 / eff_max) if node.props.get("node_effect") != 0 else 0
+ if np.sign(node.props.get("node_effect")) == 1:
  color = "blue"
- elif np.sign(node.node_effect) == -1:
+ elif np.sign(node.props.get("node_effect")) == -1:
  color = "red"
  else:
  color = "cyan"
@@ -2061,13 +2058,13 @@ class CompositionalModel2(ABC):
  tree_style.legend.add_face(TextFace(f" {eff_max * i / 4:.2f}"), column=1)
 
  if show_leaf_effects:
- leaf_name = [node.name for node in tree2.traverse("postorder") if node.is_leaf()]
+ leaf_name = [node.name for node in tree2.traverse("postorder") if node.is_leaf]
  leaf_effs = leaf_effs.loc[leaf_name].reset_index()
  palette = ["blue" if Effect > 0 else "red" for Effect in leaf_effs["Effect"].tolist()]
 
  dir_path = Path.cwd()
  dir_path = Path(dir_path / "tree_effect.png")
- tree2.render(dir_path, tree_style=tree_style, units="in")
+ tree2.render(dir_path.as_posix(), tree_style=tree_style, units="in")
  _, ax = plt.subplots(1, 2, figsize=(10, 10))
  sns.barplot(data=leaf_effs, x="Effect", y="Cell Type", palette=palette, ax=ax[1])
  img = mpimg.imread(dir_path)
@@ -2098,7 +2095,7 @@ class CompositionalModel2(ABC):
  return None
 
  @_doc_params(common_plot_args=doc_common_plot_args)
- def plot_effects_umap( # pragma: no cover
+ def plot_effects_umap( # pragma: no cover # noqa: D417
  self,
  mdata: MuData,
  effect_name: str | list | None,
@@ -2211,7 +2208,7 @@ class CompositionalModel2(ABC):
  def get_a(
  tree: tt.core.ToyTree,
  ) -> tuple[np.ndarray, int]:
- """Calculate ancestor matrix from a toytree tree
+ """Calculate ancestor matrix from a toytree tree.
 
  Args:
  tree: A toytree tree object.
@@ -2264,16 +2261,14 @@ def collapse_singularities(tree: tt.core.ToyTree) -> tt.core.ToyTree:
  A_T = A.T
  unq, count = np.unique(A_T, axis=0, return_counts=True)
 
- repeated_idx = []
- for repeated_group in unq[count > 1]:
- repeated_idx.append(np.argwhere(np.all(A_T == repeated_group, axis=1)).ravel())
+ repeated_idx = [np.argwhere(np.all(repeated_group == A_T, axis=1)).ravel() for repeated_group in unq[count > 1]]
 
  nodes_to_delete = [i for idx in repeated_idx for i in idx[1:]]
 
  # _coords.update() scrambles the idx of leaves. Therefore, keep track of it here
  tree_new = tree.copy()
  for node in tree_new.treenode.traverse():
- node.add_feature("idx_orig", node.idx)
+ node.add_prop("idx_orig", node.idx)
 
  for n in nodes_to_delete:
  node = tree_new.idx_dict[n]
@@ -2289,21 +2284,16 @@ def collapse_singularities(tree: tt.core.ToyTree) -> tt.core.ToyTree:
  return tree_new
 
 
- def traverse(df_, a, i, innerl):
- """
- Helper function for df2newick
- Adapted from https://stackoverflow.com/questions/15343338/how-to-convert-a-data-frame-to-tree-structure-object-such-as-dendrogram
+ def traverse(df_: pd.DataFrame, a: str, i: int, innerl: bool) -> str:
+ """Helper function for df2newick.
+
+ Adapted from https://stackoverflow.com/questions/15343338/how-to-convert-a-data-frame-to-tree-structure-object-such-as-dendrogram.
  """
  if i + 1 < df_.shape[1]:
  a_inner = pd.unique(df_.loc[np.where(df_.iloc[:, i] == a)].iloc[:, i + 1])
 
- desc = []
- for b in a_inner:
- desc.append(traverse(df_, b, i + 1, innerl))
- if innerl:
- il = a
- else:
- il = ""
+ desc = [traverse(df_, b, i + 1, innerl) for b in a_inner]
+ il = a if innerl else ""
  out = f"({','.join(desc)}){il}"
  else:
  out = a
@@ -2327,9 +2317,7 @@ def df2newick(df: pd.DataFrame, levels: list[str], inner_label: bool = True) ->
  df_tax = df.loc[:, [x for x in levels if x in df.columns]]
 
  alevel = pd.unique(df_tax.iloc[:, 0])
- strs = []
- for a in alevel:
- strs.append(traverse(df_tax, a, 0, inner_label))
+ strs = [traverse(df_tax, a, 0, inner_label) for a in alevel]
 
  newick = f"({','.join(strs)});"
  return newick
@@ -2340,10 +2328,10 @@ def get_a_2(
  leaf_order: list[str] = None,
  node_order: list[str] = None,
  ) -> tuple[np.ndarray, int]:
- """Calculate ancestor matrix from a ete3 tree.
+ """Calculate ancestor matrix from a ete4 tree.
 
  Args:
- tree: A ete3 tree object.
+ tree: A ete4 tree object.
  leaf_order: List of leaf names how they should appear as the rows of the ancestor matrix.
  If None, the ordering will be as in `tree.iter_leaves()`
  node_order: List of node names how they should appear as the columns of the ancestor matrix
@@ -2358,29 +2346,29 @@ def get_a_2(
  number of nodes in the tree, excluding the root node
  """
  try:
- import ete3 as ete
+ import ete4 as ete
  except ImportError:
  raise ImportError(
  "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
  ) from None
 
- n_tips = len(tree.get_leaves())
- n_nodes = len(tree.get_descendants())
+ n_tips = len(list(tree.leaves()))
+ n_nodes = len(list(tree.descendants()))
 
- node_names = [n.name for n in tree.iter_descendants()]
+ node_names = [n.name for n in tree.descendants()]
  duplicates = [x for x in node_names if node_names.count(x) > 1]
  if len(duplicates) > 0:
  raise ValueError(f"Tree nodes have duplicate names: {duplicates}. Make sure that node names are unique!")
 
  # Initialize ancestor matrix
  A_ = pd.DataFrame(np.zeros((n_tips, n_nodes)))
- A_.index = tree.get_leaf_names()
- A_.columns = [n.name for n in tree.iter_descendants()]
+ A_.index = tree.leaf_names()
+ A_.columns = [n.name for n in tree.descendants()]
 
  # Fill in 1's for all connections
- for node in tree.iter_descendants():
- for leaf in tree.get_leaves():
- if leaf in node.get_leaves():
+ for node in tree.descendants():
+ for leaf in tree.leaves():
+ if leaf in node.leaves():
  A_.loc[leaf.name, node.name] = 1
 
  # Order rows and columns
@@ -2394,15 +2382,15 @@ def get_a_2(
 
 
  def collapse_singularities_2(tree: Tree) -> Tree:
- """Collapses (deletes) nodes in a ete3 tree that are singularities (have only one child).
+ """Collapses (deletes) nodes in a ete4 tree that are singularities (have only one child).
 
  Args:
- tree: A ete3 tree object
+ tree: A ete4 tree object
 
  Returns:
- A ete3 tree without singularities.
+ A ete4 tree without singularities.
  """
- for node in tree.iter_descendants():
+ for node in tree.descendants():
  if len(node.get_children()) == 1:
  node.delete()
 
@@ -2427,13 +2415,10 @@ def linkage_to_newick(
  tree = sp_hierarchy.to_tree(Z, False)
 
  def build_newick(node, newick, parentdist, leaf_names):
- if node.is_leaf():
+ if node.is_leaf:
  return f"{leaf_names[node.id]}:{(parentdist - node.dist) / 2}{newick}"
  else:
- if len(newick) > 0:
- newick = f"):{(parentdist - node.dist) / 2}{newick}"
- else:
- newick = ");"
+ newick = f"):{(parentdist - node.dist) / 2}{newick}" if len(newick) > 0 else ");"
  newick = build_newick(node.get_left(), newick, node.dist, leaf_names)
  newick = build_newick(node.get_right(), f",{newick}", node.dist, leaf_names)
  newick = f"({newick}"
@@ -2478,10 +2463,10 @@ def import_tree(
 
  See `key_added` parameter description for the storage path of tree.
 
- tree: A ete3 tree object.
+ tree: A ete4 tree object.
  """
  try:
- import ete3 as ete
+ import ete4 as ete
  except ImportError:
  raise ImportError(
  "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
@@ -2506,32 +2491,33 @@ def import_tree(
  data_1.uns["dendrogram_cell_label"]["linkage"],
  labels=data_1.uns["dendrogram_cell_label"]["categories_ordered"],
  )
- tree = ete.Tree(newick, format=1)
+ tree = ete.Tree(newick, parser=1)
  node_id = 0
- for n in tree.iter_descendants():
- if not n.is_leaf():
+ for n in tree.descendants():
+ if not n.is_leaf:
  n.name = str(node_id)
  node_id += 1
  elif levels_orig is not None:
  newick = df2newick(data_1.obs.reset_index(), levels=levels_orig)
- tree = ete.Tree(newick, format=8)
+ tree = ete.Tree(newick, parser=8)
+
  if add_level_name:
- for n in tree.iter_descendants():
- if not n.is_leaf():
- dist = n.get_distance(n, tree)
+ for n in tree.descendants():
+ if not n.is_leaf:
+ dist = n.get_distance(n, tree, topological=True)
  n.name = f"{levels_orig[int(dist) - 1]}_{n.name}"
  elif levels_agg is not None:
  newick = df2newick(data_2.var.reset_index(), levels=levels_agg)
- tree = ete.Tree(newick, format=8)
+ tree = ete.Tree(newick, parser=8)
  if add_level_name:
- for n in tree.iter_descendants():
- if not n.is_leaf():
- dist = n.get_distance(n, tree)
+ for n in tree.descendants():
+ if not n.is_leaf:
+ dist = n.get_distance(n, tree, topological=True)
  n.name = f"{levels_agg[int(dist) - 1]}_{n.name}"
  else:
  raise ValueError("Either dendrogram_key, levels_orig or levels_agg must be specified!")
 
- node_names = [n.name for n in tree.iter_descendants()]
+ node_names = [n.name for n in tree.descendants()]
  duplicates = {x for x in node_names if node_names.count(x) > 1}
  if len(duplicates) > 0:
  raise ValueError(f"Tree nodes have duplicate names: {duplicates}. Make sure that node names are unique!")