pertpy 0.9.5__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. pertpy/__init__.py +5 -1
  2. pertpy/_doc.py +2 -5
  3. pertpy/_types.py +6 -0
  4. pertpy/data/_dataloader.py +68 -24
  5. pertpy/data/_datasets.py +9 -9
  6. pertpy/metadata/__init__.py +2 -1
  7. pertpy/metadata/_cell_line.py +136 -30
  8. pertpy/metadata/_look_up.py +13 -19
  9. pertpy/metadata/_moa.py +1 -1
  10. pertpy/preprocessing/_guide_rna.py +221 -39
  11. pertpy/preprocessing/_guide_rna_mixture.py +177 -0
  12. pertpy/tools/__init__.py +1 -1
  13. pertpy/tools/_augur.py +138 -142
  14. pertpy/tools/_cinemaot.py +75 -117
  15. pertpy/tools/_coda/_base_coda.py +150 -174
  16. pertpy/tools/_coda/_sccoda.py +66 -69
  17. pertpy/tools/_coda/_tasccoda.py +71 -79
  18. pertpy/tools/_dialogue.py +60 -56
  19. pertpy/tools/_differential_gene_expression/_base.py +25 -43
  20. pertpy/tools/_differential_gene_expression/_checks.py +4 -6
  21. pertpy/tools/_differential_gene_expression/_dge_comparison.py +5 -6
  22. pertpy/tools/_differential_gene_expression/_edger.py +6 -10
  23. pertpy/tools/_differential_gene_expression/_pydeseq2.py +1 -1
  24. pertpy/tools/_differential_gene_expression/_simple_tests.py +3 -3
  25. pertpy/tools/_differential_gene_expression/_statsmodels.py +8 -5
  26. pertpy/tools/_distances/_distance_tests.py +1 -2
  27. pertpy/tools/_distances/_distances.py +86 -92
  28. pertpy/tools/_enrichment.py +8 -25
  29. pertpy/tools/_milo.py +23 -27
  30. pertpy/tools/_mixscape.py +261 -175
  31. pertpy/tools/_perturbation_space/_clustering.py +4 -4
  32. pertpy/tools/_perturbation_space/_comparison.py +4 -4
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +83 -32
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +10 -10
  35. pertpy/tools/_perturbation_space/_simple.py +13 -17
  36. pertpy/tools/_scgen/_scgen.py +17 -20
  37. pertpy/tools/_scgen/_scgenvae.py +2 -2
  38. pertpy/tools/_scgen/_utils.py +3 -1
  39. {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/METADATA +37 -21
  40. pertpy-0.11.0.dist-info/RECORD +58 -0
  41. {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/licenses/LICENSE +1 -0
  42. pertpy/tools/_kernel_pca.py +0 -50
  43. pertpy-0.9.5.dist-info/RECORD +0 -57
  44. {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/WHEEL +0 -0
@@ -33,7 +33,7 @@ if TYPE_CHECKING:

  import numpyro as npy
  import toytree as tt
- from ete3 import Tree
+ from ete4 import Tree
  from jax._src.typing import Array
  from matplotlib.axes import Axes
  from matplotlib.colors import Colormap
@@ -198,7 +198,7 @@ class CompositionalModel2(ABC):
  *args,
  **kwargs,
  ):
- """Background function that executes any numpyro MCMC algorithm and processes its results
+ """Background function that executes any numpyro MCMC algorithm and processes its results.

  Args:
  sample_adata: anndata object with cell counts as sample_adata.X and covariates saved in sample_adata.obs.
@@ -294,6 +294,8 @@ class CompositionalModel2(ABC):
  num_warmup: Number of burn-in (warmup) samples.
  rng_key: The rng state used.
  copy: Return a copy instead of writing to adata.
+ *args: Additional args passed to numpyro NUTS
+ **kwargs: Additional kwargs passed to numpyro NUTS

  Returns:
  Calls `self.__run_mcmc`
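The hunk above (and the matching one for `run_hmc` that follows) documents that extra positional and keyword arguments are forwarded to the corresponding numpyro kernel. A minimal usage sketch, illustrative only and not part of this diff: the load/prepare calls follow the scCODA docstring examples elsewhere in this file, and `target_accept_prob` is a standard `numpyro.infer.NUTS` option assumed to be forwarded as described.

    import pertpy as pt

    haber_cells = pt.dt.haber_2017_regions()
    sccoda = pt.tl.Sccoda()
    mdata = sccoda.load(
        haber_cells,
        type="cell_level",
        generate_sample_level=True,
        cell_type_identifier="cell_label",
        sample_identifier="batch",
        covariate_obs=["condition"],
    )
    sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
    # Extra kwargs are passed through to numpyro NUTS per the added docstring lines.
    sccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42, target_accept_prob=0.9)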
@@ -347,6 +349,8 @@ class CompositionalModel2(ABC):
  num_warmup: Number of burn-in (warmup) samples.
  rng_key: The rng state used. If None, a random state will be selected.
  copy: Return a copy instead of writing to adata.
+ *args: Additional args passed to numpyro HMC
+ **kwargs: Additional kwargs passed to numpyro HMC

  Examples:
  >>> import pertpy as pt
@@ -396,7 +400,8 @@ class CompositionalModel2(ABC):
  self, sample_adata: AnnData, est_fdr: float = 0.05, *args, **kwargs
  ) -> tuple[pd.DataFrame, pd.DataFrame] | tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
  """Generates summary dataframes for intercepts, effects and node-level effect (if using tree aggregation).
- This function builds on and supports all functionalities from ``az.summary``.
+
+ This function builds on and supports all functionalities from ``az.summary``.

  Args:
  sample_adata: Anndata object with cell counts as sample_adata.X and covariates saved in sample_adata.obs.
@@ -405,7 +410,7 @@ class CompositionalModel2(ABC):
  kwargs: Passed to ``az.summary``

  Returns:
- Tuple[pd.DataFrame, pd.DataFrame] or Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: Intercept, effect and node-level DataFrames
+ Tuple[:class:pandas.DataFrame, :class:pandas.DataFrame] or Tuple[:class:pandas.DataFrame, :class:pandas.DataFrame, :class:pandas.DataFrame]: Intercept, effect and node-level DataFrames

  intercept_df
  Summary of intercept parameters. Contains one row per cell type.
@@ -435,7 +440,7 @@ class CompositionalModel2(ABC):
  - Delta: Decision boundary value - threshold of practical significance
  - Is credible: Boolean indicator whether effect is credible

- Examples:
+ Examples:
  >>> import pertpy as pt
  >>> haber_cells = pt.dt.haber_2017_regions()
  >>> sccoda = pt.tl.Sccoda()
@@ -684,7 +689,7 @@ class CompositionalModel2(ABC):

  if fdr < alpha:
  # ceiling with 3 decimals precision
- c = np.floor(c * 10**3) / 10**3
+ c = np.floor(c * 10**3) / 10**3 # noqa: PLW2901
  return c, fdr
  return 1.0, 0

@@ -737,7 +742,8 @@ class CompositionalModel2(ABC):
  node_df: pd.DataFrame,
  ) -> pd.DataFrame:
  """Evaluation of MCMC results for node-level effect parameters. This function is only used within self.summary_prepare.
- This function determines whether node-level effects are credible or not
+
+ This function determines whether node-level effects are credible or not.

  Args:
  sample_adata: Anndata object with cell counts as sample_adata.X and covariates saved in sample_adata.obs.
@@ -850,7 +856,7 @@ class CompositionalModel2(ABC):
  table = Table(title="Compositional Analysis summary", box=box.SQUARE, expand=True, highlight=True)
  table.add_column("Name", justify="left", style="cyan")
  table.add_column("Value", justify="left")
- table.add_row("Data", "Data: %d samples, %d cell types" % data_dims)
+ table.add_row("Data", f"Data: {data_dims[0]} samples, {data_dims[1]} cell types")
  table.add_row("Reference cell type", "{}".format(str(sample_adata.uns["scCODA_params"]["reference_cell_type"])))
  table.add_row("Formula", "{}".format(sample_adata.uns["scCODA_params"]["formula"]))
  if extended:
@@ -932,15 +938,15 @@ class CompositionalModel2(ABC):
  )
  console.print(table)

- def get_intercept_df(self, data: AnnData | MuData, modality_key: str = "coda"):
- """Get intercept dataframe as printed in the extended summary
+ def get_intercept_df(self, data: AnnData | MuData, modality_key: str = "coda") -> pd.DataFrame:
+ """Get intercept dataframe as printed in the extended summary.

  Args:
  data: AnnData object or MuData object.
  modality_key: If data is a MuData object, specify which modality to use.

  Returns:
- pd.DataFrame: Intercept data frame.
+ Intercept data frame.

  Examples:
  >>> import pertpy as pt
@@ -963,15 +969,15 @@ class CompositionalModel2(ABC):

  return sample_adata.varm["intercept_df"]

- def get_effect_df(self, data: AnnData | MuData, modality_key: str = "coda"):
- """Get effect dataframe as printed in the extended summary
+ def get_effect_df(self, data: AnnData | MuData, modality_key: str = "coda") -> pd.DataFrame:
+ """Get effect dataframe as printed in the extended summary.

  Args:
  data: AnnData object or MuData object.
  modality_key: If data is a MuData object, specify which modality to use.

  Returns:
- pd.DataFrame: Effect data frame.
+ Effect data frame.

  Examples:
  >>> import pertpy as pt
@@ -1005,15 +1011,15 @@ class CompositionalModel2(ABC):

  return effect_df

- def get_node_df(self, data: AnnData | MuData, modality_key: str = "coda"):
- """Get node effect dataframe as printed in the extended summary of a tascCODA model
+ def get_node_df(self, data: AnnData | MuData, modality_key: str = "coda") -> pd.DataFrame:
+ """Get node effect dataframe as printed in the extended summary of a tascCODA model.

  Args:
  data: AnnData object or MuData object.
  modality_key: If data is a MuData object, specify which modality to use.

  Returns:
- pd.DataFrame: Node effect data frame.
+ Node effect data frame.

  Examples:
  >>> import pertpy as pt
@@ -1030,7 +1036,6 @@ class CompositionalModel2(ABC):
  >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
  >>> node_effects = tasccoda.get_node_df(mdata)
  """
-
  if isinstance(data, MuData):
  try:
  sample_adata = data[modality_key]
@@ -1043,8 +1048,9 @@ class CompositionalModel2(ABC):
  return sample_adata.uns["scCODA_params"]["node_df"]

  def set_fdr(self, data: AnnData | MuData, est_fdr: float, modality_key: str = "coda", *args, **kwargs):
- """Direct posterior probability approach to calculate credible effects while keeping the expected FDR at a certain level
- Note: Does not work for spike-and-slab LASSO selection method
+ """Direct posterior probability approach to calculate credible effects while keeping the expected FDR at a certain level.
+
+ Note: Does not work for spike-and-slab LASSO selection method.

  Args:
  data: AnnData object or MuData object.
@@ -1079,7 +1085,8 @@ class CompositionalModel2(ABC):

  def credible_effects(self, data: AnnData | MuData, modality_key: str = "coda", est_fdr: float = None) -> pd.Series:
  """Decides which effects of the scCODA model are credible based on an adjustable inclusion probability threshold.
- Note: Parameter est_fdr has no effect for spike-and-slab LASSO selection method
+
+ Note: Parameter est_fdr has no effect for spike-and-slab LASSO selection method.

  Args:
  data: AnnData object or MuData object.
@@ -1087,7 +1094,7 @@ class CompositionalModel2(ABC):
  est_fdr: Estimated false discovery rate. Must be between 0 and 1.

  Returns:
- pd.Series: Credible effect decision series which includes boolean values indicate whether effects are credible under inc_prob_threshold.
+ Credible effect decision series which includes boolean values indicate whether effects are credible under inc_prob_threshold.
  """
  if isinstance(data, MuData):
  try:
@@ -1109,16 +1116,15 @@ class CompositionalModel2(ABC):
  else:
  _, eff_df = self.summary_prepare(sample_adata, est_fdr=est_fdr) # type: ignore
  # otherwise, get pre-calculated DataFrames. Effect DataFrame is stitched together from varm
+ elif model_type == "tree_agg" and select_type == "sslasso":
+ eff_df = sample_adata.uns["scCODA_params"]["node_df"]
  else:
- if model_type == "tree_agg" and select_type == "sslasso":
- eff_df = sample_adata.uns["scCODA_params"]["node_df"]
- else:
- covariates = sample_adata.uns["scCODA_params"]["covariate_names"]
- effect_dfs = [sample_adata.varm[f"effect_df_{cov}"] for cov in covariates]
- eff_df = pd.concat(effect_dfs)
- eff_df.index = pd.MultiIndex.from_product(
- (covariates, sample_adata.var.index.tolist()), names=["Covariate", "Cell Type"]
- )
+ covariates = sample_adata.uns["scCODA_params"]["covariate_names"]
+ effect_dfs = [sample_adata.varm[f"effect_df_{cov}"] for cov in covariates]
+ eff_df = pd.concat(effect_dfs)
+ eff_df.index = pd.MultiIndex.from_product(
+ (covariates, sample_adata.var.index.tolist()), names=["Covariate", "Cell Type"]
+ )

  out = eff_df["Final Parameter"] != 0
  out.rename("credible change")
@@ -1188,7 +1194,7 @@ class CompositionalModel2(ABC):
  return ax

  @_doc_params(common_plot_args=doc_common_plot_args)
- def plot_stacked_barplot( # pragma: no cover
+ def plot_stacked_barplot( # pragma: no cover # noqa: D417
  self,
  data: AnnData | MuData,
  feature_name: str,
@@ -1199,7 +1205,6 @@ class CompositionalModel2(ABC):
  level_order: list[str] = None,
  figsize: tuple[float, float] | None = None,
  dpi: int | None = 100,
- show: bool = True,
  return_fig: bool = False,
  ) -> Figure | None:
  """Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples").
@@ -1216,7 +1221,7 @@ class CompositionalModel2(ABC):
  {common_plot_args}

  Returns:
- If `return_fig` is `True`, returns the figure, otherwise `None`.
+ If `return_fig` is `True`, returns the Figure, otherwise `None`.

  Examples:
  >>> import pertpy as pt
@@ -1231,8 +1236,6 @@ class CompositionalModel2(ABC):
  """
  if isinstance(data, MuData):
  data = data[modality_key]
- if isinstance(data, AnnData):
- data = data

  ct_names = data.var.index

@@ -1278,14 +1281,13 @@ class CompositionalModel2(ABC):
  show_legend=show_legend,
  )

- if show:
- plt.show()
  if return_fig:
  return plt.gcf()
+ plt.show()

  return None

  @_doc_params(common_plot_args=doc_common_plot_args)
- def plot_effects_barplot( # pragma: no cover
+ def plot_effects_barplot( # pragma: no cover # noqa: D417
  self,
  data: AnnData | MuData,
  *,
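The plotting hunks above drop the `show` parameter: from 0.11.0 on, these methods display the figure by default and return it only when `return_fig=True`. A brief usage sketch of the new call pattern, illustrative only and continuing from the fitted `mdata` in the sketch further up; the output path is hypothetical.

    import matplotlib.pyplot as plt

    # 0.9.5: sccoda.plot_stacked_barplot(mdata, feature_name="condition", show=False)
    # 0.11.0: `show` is gone; request the Figure back instead of displaying it.
    fig = sccoda.plot_stacked_barplot(mdata, feature_name="condition", return_fig=True)
    fig.savefig("condition_proportions.png", dpi=300, bbox_inches="tight")  # hypothetical path
    plt.close(fig)

    # Without return_fig the method calls plt.show() and returns None.
    sccoda.plot_effects_barplot(mdata)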
@@ -1300,7 +1302,6 @@ class CompositionalModel2(ABC):
  args_barplot: dict | None = None,
  figsize: tuple[float, float] | None = None,
  dpi: int | None = 100,
- show: bool = True,
  return_fig: bool = False,
  ) -> Figure | None:
  """Barplot visualization for effects.
@@ -1343,8 +1344,7 @@ class CompositionalModel2(ABC):
  args_barplot = {}
  if isinstance(data, MuData):
  data = data[modality_key]
- if isinstance(data, AnnData):
- data = data
+
  # Get covariate names from adata, partition into those with nonzero effects for min. one cell type/no cell types
  covariate_names = data.uns["scCODA_params"]["covariate_names"]
  if covariates is not None:
@@ -1375,18 +1375,16 @@ class CompositionalModel2(ABC):

  plot_df = plot_df.reset_index()

- if len(covariate_names_zero) != 0:
- if plot_facets:
- if plot_zero_covariate and not plot_zero_cell_type:
- for covariate_name_zero in covariate_names_zero:
- new_row = {
- "Covariate": covariate_name_zero,
- "Cell Type": "zero",
- "value": 0,
- }
- plot_df = pd.concat([plot_df, pd.DataFrame([new_row])], ignore_index=True)
- plot_df["covariate_"] = pd.Categorical(plot_df["Covariate"], covariate_names)
- plot_df = plot_df.sort_values(["covariate_"])
+ if len(covariate_names_zero) != 0 and plot_facets and plot_zero_covariate and not plot_zero_cell_type:
+ for covariate_name_zero in covariate_names_zero:
+ new_row = {
+ "Covariate": covariate_name_zero,
+ "Cell Type": "zero",
+ "value": 0,
+ }
+ plot_df = pd.concat([plot_df, pd.DataFrame([new_row])], ignore_index=True)
+ plot_df["covariate_"] = pd.Categorical(plot_df["Covariate"], covariate_names)
+ plot_df = plot_df.sort_values(["covariate_"])
  if not plot_zero_cell_type:
  cell_type_names_zero = [
  name
@@ -1430,9 +1428,8 @@ class CompositionalModel2(ABC):
  ax.set_title(covariate_names[i])
  if len(ax.get_xticklabels()) < 5:
  ax.set_aspect(10 / len(ax.get_xticklabels()))
- if len(ax.get_xticklabels()) == 1:
- if ax.get_xticklabels()[0]._text == "zero":
- ax.set_xticks([])
+ if len(ax.get_xticklabels()) == 1 and ax.get_xticklabels()[0]._text == "zero":
+ ax.set_xticks([])

  # If not plot as facets, call barplot to plot cell types on the x-axis.
  else:
@@ -1463,16 +1460,18 @@ class CompositionalModel2(ABC):
  ax=ax,
  )
  cell_types = pd.unique(plot_df["Cell Type"])
+ ax.set_xticks(cell_types)
  ax.set_xticklabels(cell_types, rotation=90)

- if show:
- plt.show()
- if return_fig:
+ if return_fig and plot_facets:
+ return g
+ if return_fig and not plot_facets:
  return plt.gcf()
+ plt.show()
  return None

  @_doc_params(common_plot_args=doc_common_plot_args)
- def plot_boxplots( # pragma: no cover
+ def plot_boxplots( # pragma: no cover # noqa: D417
  self,
  data: AnnData | MuData,
  feature_name: str,
@@ -1489,7 +1488,6 @@ class CompositionalModel2(ABC):
  level_order: list[str] = None,
  figsize: tuple[float, float] | None = None,
  dpi: int | None = 100,
- show: bool = True,
  return_fig: bool = False,
  ) -> Figure | None:
  """Grouped boxplot visualization.
@@ -1535,8 +1533,7 @@ class CompositionalModel2(ABC):
  args_swarmplot = {}
  if isinstance(data, MuData):
  data = data[modality_key]
- if isinstance(data, AnnData):
- data = data
+
  # y scale transformations
  if y_scale == "relative":
  sample_sums = np.sum(data.X, axis=1, keepdims=True)
@@ -1610,10 +1607,7 @@ class CompositionalModel2(ABC):
  )

  if add_dots:
- if "hue" in args_swarmplot:
- hue = args_swarmplot.pop("hue")
- else:
- hue = None
+ hue = args_swarmplot.pop("hue") if "hue" in args_swarmplot else None

  if hue is None:
  g.map(
@@ -1678,6 +1672,7 @@ class CompositionalModel2(ABC):
  )

  cell_types = pd.unique(plot_df["Cell type"])
+ ax.set_xticks(cell_types)
  ax.set_xticklabels(cell_types, rotation=90)

  if show_legend:
@@ -1697,14 +1692,15 @@ class CompositionalModel2(ABC):
  title=feature_name,
  )

- if show:
- plt.show()
- if return_fig:
+ if return_fig and plot_facets:
+ return g
+ if return_fig and not plot_facets:
  return plt.gcf()
+ plt.show()
  return None

  @_doc_params(common_plot_args=doc_common_plot_args)
- def plot_rel_abundance_dispersion_plot( # pragma: no cover
+ def plot_rel_abundance_dispersion_plot( # pragma: no cover # noqa: D417
  self,
  data: AnnData | MuData,
  *,
@@ -1716,7 +1712,6 @@ class CompositionalModel2(ABC):
  figsize: tuple[float, float] | None = None,
  dpi: int | None = 100,
  ax: plt.Axes | None = None,
- show: bool = True,
  return_fig: bool = False,
  ) -> Figure | None:
  """Plots total variance of relative abundance versus minimum relative abundance of all cell types for determination of a reference cell type.
@@ -1753,8 +1748,7 @@ class CompositionalModel2(ABC):
  """
  if isinstance(data, MuData):
  data = data[modality_key]
- if isinstance(data, AnnData):
- data = data
+
  if ax is None:
  _, ax = plt.subplots(figsize=figsize, dpi=dpi)

@@ -1820,34 +1814,32 @@ class CompositionalModel2(ABC):

  ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1, title="Is abundant")

- if show:
- plt.show()
  if return_fig:
  return plt.gcf()
+ plt.show()
  return None

  @_doc_params(common_plot_args=doc_common_plot_args)
- def plot_draw_tree( # pragma: no cover
+ def plot_draw_tree( # pragma: no cover # noqa: D417
  self,
  data: AnnData | MuData,
  *,
  modality_key: str = "coda",
- tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors
- tight_text: bool | None = False,
+ tree: str = "tree", # Also type ete4.Tree. Omitted due to import errors
+ tight_text: bool = False,
  show_scale: bool | None = False,
  units: Literal["px", "mm", "in"] | None = "px",
  figsize: tuple[float, float] | None = (None, None),
  dpi: int | None = 100,
  save: str | bool = False,
- show: bool = True,
  return_fig: bool = False,
  ) -> Tree | None:
- """Plot a tree using input ete3 tree object.
+ """Plot a tree using input ete4 tree object.

  Args:
  data: AnnData object or MuData object.
  modality_key: If data is a MuData object, specify which modality to use.
- tree: A ete3 tree object or a str to indicate the tree stored in `.uns`.
+ tree: A ete4 tree object or a str to indicate the tree stored in `.uns`.
  tight_text: When False, boundaries of the text are approximated according to general font metrics,
  producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
  show_scale: Include the scale legend in the tree image or not.
@@ -1858,7 +1850,7 @@ class CompositionalModel2(ABC):
  {common_plot_args}

  Returns:
- Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`) or plot the tree inline (`show = False`)
+ Depending on `save`, returns :class:`ete4.core.tree.Tree` and :class:`ete4.treeview.TreeStyle` (`save = 'output.png'`) or plot the tree inline (`save = False`)

  Examples:
  >>> import pertpy as pt
@@ -1879,7 +1871,8 @@ class CompositionalModel2(ABC):
  .. image:: /_static/docstring_previews/tasccoda_draw_tree.png
  """
  try:
- from ete3 import CircleFace, NodeStyle, TextFace, Tree, TreeStyle, faces
+ from ete4 import Tree
+ from ete4.treeview import CircleFace, NodeStyle, TextFace, TreeStyle, faces
  except ImportError:
  raise ImportError(
  "To use tasccoda please install additional dependencies with `pip install pertpy[coda]`"
@@ -1887,8 +1880,6 @@ class CompositionalModel2(ABC):

  if isinstance(data, MuData):
  data = data[modality_key]
- if isinstance(data, AnnData):
- data = data
  if isinstance(tree, str):
  tree = data.uns[tree]

@@ -1901,22 +1892,21 @@ class CompositionalModel2(ABC):
  tree_style.layout_fn = my_layout
  tree_style.show_scale = show_scale

- if save is not None:
+ if save:
  tree.render(save, tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore
- if show:
- return tree.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore
  if return_fig:
  return tree, tree_style
+ return tree.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore
  return None

  @_doc_params(common_plot_args=doc_common_plot_args)
- def plot_draw_effects( # pragma: no cover
+ def plot_draw_effects( # pragma: no cover # noqa: D417
  self,
  data: AnnData | MuData,
  covariate: str,
  *,
  modality_key: str = "coda",
- tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors
+ tree: str = "tree", # Also type ete4.Tree. Omitted due to import errors
  show_legend: bool | None = None,
  show_leaf_effects: bool | None = False,
  tight_text: bool | None = False,
@@ -1925,7 +1915,6 @@ class CompositionalModel2(ABC):
  figsize: tuple[float, float] | None = (None, None),
  dpi: int | None = 100,
  save: str | bool = False,
- show: bool = True,
  return_fig: bool = False,
  ) -> Tree | None:
  """Plot a tree with colored circles on the nodes indicating significant effects with bar plots which indicate leave-level significant effects.
@@ -1934,7 +1923,7 @@ class CompositionalModel2(ABC):
  data: AnnData object or MuData object.
  covariate: The covariate, whose effects should be plotted.
  modality_key: If data is a MuData object, specify which modality to use.
- tree: A ete3 tree object or a str to indicate the tree stored in `.uns`.
+ tree: A ete4 tree object or a str to indicate the tree stored in `.uns`.
  show_legend: If show legend of nodes significant effects or not.
  Defaults to False if show_leaf_effects is True.
  show_leaf_effects: If True, plot bar plots which indicate leave-level significant effects.
@@ -1948,8 +1937,8 @@ class CompositionalModel2(ABC):
  {common_plot_args}

  Returns:
- Returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`return_fig = False`)
- or plot the tree inline (`show = True`)
+ Depending on `save`, returns :class:`ete4.core.tree.Tree` and :class:`ete4.treeview.TreeStyle` (`save = 'output.png'`)
+ or plot the tree inline (`save = False`).

  Examples:
  >>> import pertpy as pt
@@ -1970,7 +1959,8 @@ class CompositionalModel2(ABC):
  .. image:: /_static/docstring_previews/tasccoda_draw_effects.png
  """
  try:
- from ete3 import CircleFace, NodeStyle, TextFace, Tree, TreeStyle, faces
+ from ete4 import Tree
+ from ete4.treeview import CircleFace, NodeStyle, TextFace, TreeStyle, faces
  except ImportError:
  raise ImportError(
  "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
@@ -1978,8 +1968,6 @@ class CompositionalModel2(ABC):

  if isinstance(data, MuData):
  data = data[modality_key]
- if isinstance(data, AnnData):
- data = data
  if show_legend is None:
  show_legend = not show_leaf_effects
  elif show_legend:
@@ -2010,18 +1998,18 @@ class CompositionalModel2(ABC):
  n.set_style(nstyle)
  if n.name in node_effs.index:
  e = node_effs.loc[n.name, "Final Parameter"]
- n.add_feature("node_effect", e)
+ n.add_prop("node_effect", e)
  else:
- n.add_feature("node_effect", 0)
+ n.add_prop("node_effect", 0)
  if n.name in leaf_effs.index:
  e = leaf_effs.loc[n.name, "Effect"]
- n.add_feature("leaf_effect", e)
+ n.add_prop("leaf_effect", e)
  else:
- n.add_feature("leaf_effect", 0)
+ n.add_prop("leaf_effect", 0)

  # Scale effect values to get nice node sizes
- eff_max = np.max([np.abs(n.node_effect) for n in tree2.traverse()])
- leaf_eff_max = np.max([np.abs(n.leaf_effect) for n in tree2.traverse()])
+ eff_max = np.max([np.abs(n.props.get("node_effect")) for n in tree2.traverse()])
+ leaf_eff_max = np.max([np.abs(n.props.get("leaf_effect")) for n in tree2.traverse()])

  def my_layout(node):
  text_face = TextFace(node.name, tight_text=tight_text)
@@ -2029,10 +2017,10 @@ class CompositionalModel2(ABC):
  faces.add_face_to_node(text_face, node, column=0, aligned=True)

  # if node.is_leaf():
- size = (np.abs(node.node_effect) * 10 / eff_max) if node.node_effect != 0 else 0
- if np.sign(node.node_effect) == 1:
+ size = (np.abs(node.props.get("node_effect")) * 10 / eff_max) if node.props.get("node_effect") != 0 else 0
+ if np.sign(node.props.get("node_effect")) == 1:
  color = "blue"
- elif np.sign(node.node_effect) == -1:
+ elif np.sign(node.props.get("node_effect")) == -1:
  color = "red"
  else:
  color = "cyan"
@@ -2068,13 +2056,13 @@ class CompositionalModel2(ABC):
  tree_style.legend.add_face(TextFace(f" {eff_max * i / 4:.2f}"), column=1)

  if show_leaf_effects:
- leaf_name = [node.name for node in tree2.traverse("postorder") if node.is_leaf()]
+ leaf_name = [node.name for node in tree2.traverse("postorder") if node.is_leaf]
  leaf_effs = leaf_effs.loc[leaf_name].reset_index()
  palette = ["blue" if Effect > 0 else "red" for Effect in leaf_effs["Effect"].tolist()]

  dir_path = Path.cwd()
  dir_path = Path(dir_path / "tree_effect.png")
- tree2.render(dir_path, tree_style=tree_style, units="in")
+ tree2.render(dir_path.as_posix(), tree_style=tree_style, units="in")
  _, ax = plt.subplots(1, 2, figsize=(10, 10))
  sns.barplot(data=leaf_effs, x="Effect", y="Cell Type", palette=palette, ax=ax[1])
  img = mpimg.imread(dir_path)
@@ -2092,19 +2080,20 @@ class CompositionalModel2(ABC):

  if save:
  plt.savefig(save)
+ if return_fig:
+ return plt.gcf()

- if save and not show_leaf_effects:
- tree2.render(save, tree_style=tree_style, units=units)
- if show:
- if not show_leaf_effects:
- return tree2.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi)
- if return_fig:
- if not show_leaf_effects:
+ else:
+ if save:
+ tree2.render(save, tree_style=tree_style, units=units)
+ if return_fig:
  return tree2, tree_style
+ return tree2.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi)
+
  return None

  @_doc_params(common_plot_args=doc_common_plot_args)
- def plot_effects_umap( # pragma: no cover
+ def plot_effects_umap( # pragma: no cover # noqa: D417
  self,
  mdata: MuData,
  effect_name: str | list | None,
@@ -2115,7 +2104,6 @@ class CompositionalModel2(ABC):
  color_map: Colormap | str | None = None,
  palette: str | Sequence[str] | None = None,
  ax: Axes = None,
- show: bool = True,
  return_fig: bool = False,
  **kwargs,
  ) -> Figure | None:
@@ -2209,17 +2197,16 @@ class CompositionalModel2(ABC):
  **kwargs,
  )

- if show:
- plt.show()
  if return_fig:
  return fig
+ plt.show()
  return None


  def get_a(
  tree: tt.core.ToyTree,
  ) -> tuple[np.ndarray, int]:
- """Calculate ancestor matrix from a toytree tree
+ """Calculate ancestor matrix from a toytree tree.

  Args:
  tree: A toytree tree object.
@@ -2272,16 +2259,14 @@ def collapse_singularities(tree: tt.core.ToyTree) -> tt.core.ToyTree:
  A_T = A.T
  unq, count = np.unique(A_T, axis=0, return_counts=True)

- repeated_idx = []
- for repeated_group in unq[count > 1]:
- repeated_idx.append(np.argwhere(np.all(A_T == repeated_group, axis=1)).ravel())
+ repeated_idx = [np.argwhere(np.all(repeated_group == A_T, axis=1)).ravel() for repeated_group in unq[count > 1]]

  nodes_to_delete = [i for idx in repeated_idx for i in idx[1:]]

  # _coords.update() scrambles the idx of leaves. Therefore, keep track of it here
  tree_new = tree.copy()
  for node in tree_new.treenode.traverse():
- node.add_feature("idx_orig", node.idx)
+ node.add_prop("idx_orig", node.idx)

  for n in nodes_to_delete:
  node = tree_new.idx_dict[n]
@@ -2297,21 +2282,16 @@ def collapse_singularities(tree: tt.core.ToyTree) -> tt.core.ToyTree:
  return tree_new


- def traverse(df_, a, i, innerl):
- """
- Helper function for df2newick
- Adapted from https://stackoverflow.com/questions/15343338/how-to-convert-a-data-frame-to-tree-structure-object-such-as-dendrogram
+ def traverse(df_: pd.DataFrame, a: str, i: int, innerl: bool) -> str:
+ """Helper function for df2newick.
+
+ Adapted from https://stackoverflow.com/questions/15343338/how-to-convert-a-data-frame-to-tree-structure-object-such-as-dendrogram.
  """
  if i + 1 < df_.shape[1]:
  a_inner = pd.unique(df_.loc[np.where(df_.iloc[:, i] == a)].iloc[:, i + 1])

- desc = []
- for b in a_inner:
- desc.append(traverse(df_, b, i + 1, innerl))
- if innerl:
- il = a
- else:
- il = ""
+ desc = [traverse(df_, b, i + 1, innerl) for b in a_inner]
+ il = a if innerl else ""
  out = f"({','.join(desc)}){il}"
  else:
  out = a
@@ -2335,9 +2315,7 @@ def df2newick(df: pd.DataFrame, levels: list[str], inner_label: bool = True) ->
  df_tax = df.loc[:, [x for x in levels if x in df.columns]]

  alevel = pd.unique(df_tax.iloc[:, 0])
- strs = []
- for a in alevel:
- strs.append(traverse(df_tax, a, 0, inner_label))
+ strs = [traverse(df_tax, a, 0, inner_label) for a in alevel]

  newick = f"({','.join(strs)});"
  return newick
@@ -2348,10 +2326,10 @@ def get_a_2(
  leaf_order: list[str] = None,
  node_order: list[str] = None,
  ) -> tuple[np.ndarray, int]:
- """Calculate ancestor matrix from a ete3 tree.
+ """Calculate ancestor matrix from a ete4 tree.

  Args:
- tree: A ete3 tree object.
+ tree: A ete4 tree object.
  leaf_order: List of leaf names how they should appear as the rows of the ancestor matrix.
  If None, the ordering will be as in `tree.iter_leaves()`
  node_order: List of node names how they should appear as the columns of the ancestor matrix
@@ -2366,29 +2344,29 @@ def get_a_2(
  number of nodes in the tree, excluding the root node
  """
  try:
- import ete3 as ete
+ import ete4 as ete
  except ImportError:
  raise ImportError(
  "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
  ) from None

- n_tips = len(tree.get_leaves())
- n_nodes = len(tree.get_descendants())
+ n_tips = len(list(tree.leaves()))
+ n_nodes = len(list(tree.descendants()))

- node_names = [n.name for n in tree.iter_descendants()]
+ node_names = [n.name for n in tree.descendants()]
  duplicates = [x for x in node_names if node_names.count(x) > 1]
  if len(duplicates) > 0:
  raise ValueError(f"Tree nodes have duplicate names: {duplicates}. Make sure that node names are unique!")

  # Initialize ancestor matrix
  A_ = pd.DataFrame(np.zeros((n_tips, n_nodes)))
- A_.index = tree.get_leaf_names()
- A_.columns = [n.name for n in tree.iter_descendants()]
+ A_.index = tree.leaf_names()
+ A_.columns = [n.name for n in tree.descendants()]

  # Fill in 1's for all connections
- for node in tree.iter_descendants():
- for leaf in tree.get_leaves():
- if leaf in node.get_leaves():
+ for node in tree.descendants():
+ for leaf in tree.leaves():
+ if leaf in node.leaves():
  A_.loc[leaf.name, node.name] = 1

  # Order rows and columns
@@ -2402,15 +2380,15 @@


  def collapse_singularities_2(tree: Tree) -> Tree:
- """Collapses (deletes) nodes in a ete3 tree that are singularities (have only one child).
+ """Collapses (deletes) nodes in a ete4 tree that are singularities (have only one child).

  Args:
- tree: A ete3 tree object
+ tree: A ete4 tree object

  Returns:
- A ete3 tree without singularities.
+ A ete4 tree without singularities.
  """
- for node in tree.iter_descendants():
+ for node in tree.descendants():
  if len(node.get_children()) == 1:
  node.delete()

@@ -2435,13 +2413,10 @@ def linkage_to_newick(
  tree = sp_hierarchy.to_tree(Z, False)

  def build_newick(node, newick, parentdist, leaf_names):
- if node.is_leaf():
+ if node.is_leaf:
  return f"{leaf_names[node.id]}:{(parentdist - node.dist) / 2}{newick}"
  else:
- if len(newick) > 0:
- newick = f"):{(parentdist - node.dist) / 2}{newick}"
- else:
- newick = ");"
+ newick = f"):{(parentdist - node.dist) / 2}{newick}" if len(newick) > 0 else ");"
  newick = build_newick(node.get_left(), newick, node.dist, leaf_names)
  newick = build_newick(node.get_right(), f",{newick}", node.dist, leaf_names)
  newick = f"({newick}"
@@ -2486,10 +2461,10 @@ def import_tree(

  See `key_added` parameter description for the storage path of tree.

- tree: A ete3 tree object.
+ tree: A ete4 tree object.
  """
  try:
- import ete3 as ete
+ import ete4 as ete
  except ImportError:
  raise ImportError(
  "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
@@ -2514,32 +2489,33 @@ def import_tree(
  data_1.uns["dendrogram_cell_label"]["linkage"],
  labels=data_1.uns["dendrogram_cell_label"]["categories_ordered"],
  )
- tree = ete.Tree(newick, format=1)
+ tree = ete.Tree(newick, parser=1)
  node_id = 0
- for n in tree.iter_descendants():
- if not n.is_leaf():
+ for n in tree.descendants():
+ if not n.is_leaf:
  n.name = str(node_id)
  node_id += 1
  elif levels_orig is not None:
  newick = df2newick(data_1.obs.reset_index(), levels=levels_orig)
- tree = ete.Tree(newick, format=8)
+ tree = ete.Tree(newick, parser=8)
+
  if add_level_name:
- for n in tree.iter_descendants():
- if not n.is_leaf():
- dist = n.get_distance(n, tree)
+ for n in tree.descendants():
+ if not n.is_leaf:
+ dist = n.get_distance(n, tree, topological=True)
  n.name = f"{levels_orig[int(dist) - 1]}_{n.name}"
  elif levels_agg is not None:
  newick = df2newick(data_2.var.reset_index(), levels=levels_agg)
- tree = ete.Tree(newick, format=8)
+ tree = ete.Tree(newick, parser=8)
  if add_level_name:
- for n in tree.iter_descendants():
- if not n.is_leaf():
- dist = n.get_distance(n, tree)
+ for n in tree.descendants():
+ if not n.is_leaf:
+ dist = n.get_distance(n, tree, topological=True)
  n.name = f"{levels_agg[int(dist) - 1]}_{n.name}"
  else:
  raise ValueError("Either dendrogram_key, levels_orig or levels_agg must be specified!")

- node_names = [n.name for n in tree.iter_descendants()]
+ node_names = [n.name for n in tree.descendants()]
  duplicates = {x for x in node_names if node_names.count(x) > 1}
  if len(duplicates) > 0:
  raise ValueError(f"Tree nodes have duplicate names: {duplicates}. Make sure that node names are unique!")
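Most of the remaining hunks apply one recurring ete3 → ete4 migration: the constructor flag `format=` becomes `parser=`, `add_feature` becomes `add_prop` with values read back from `node.props`, the `iter_descendants()` / `get_leaves()` / `get_leaf_names()` calls become the `descendants()` / `leaves()` / `leaf_names()` generators, and `is_leaf` is now a property. A self-contained sketch of the renamed calls on a toy tree (illustrative only; assumes ete4 is installed, e.g. via `pip install pertpy[coda]`):

    from ete4 import Tree

    newick = "((A:1,B:1)ab:1,(C:1,D:1)cd:1)root;"  # toy tree, purely for illustration
    tree = Tree(newick, parser=1)                   # ete3 used Tree(newick, format=1)

    for node in tree.descendants():                 # ete3: tree.iter_descendants()
        node.add_prop("node_effect", 0.0)           # ete3: node.add_feature("node_effect", 0.0)

    n_tips = len(list(tree.leaves()))               # ete3: len(tree.get_leaves())
    leaf_names = list(tree.leaf_names())            # ete3: tree.get_leaf_names()
    internal = [n.name for n in tree.descendants() if not n.is_leaf]  # is_leaf is a property

    # Stored properties are read back through the node.props mapping.
    effects = {n.name: n.props.get("node_effect", 0.0) for n in tree.traverse()}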