pertpy 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. pertpy/__init__.py +5 -1
  2. pertpy/_doc.py +1 -3
  3. pertpy/_types.py +6 -0
  4. pertpy/data/_dataloader.py +68 -24
  5. pertpy/data/_datasets.py +9 -9
  6. pertpy/metadata/__init__.py +2 -1
  7. pertpy/metadata/_cell_line.py +133 -25
  8. pertpy/metadata/_look_up.py +13 -19
  9. pertpy/metadata/_moa.py +1 -1
  10. pertpy/preprocessing/_guide_rna.py +138 -44
  11. pertpy/preprocessing/_guide_rna_mixture.py +17 -19
  12. pertpy/tools/__init__.py +1 -1
  13. pertpy/tools/_augur.py +106 -98
  14. pertpy/tools/_cinemaot.py +74 -114
  15. pertpy/tools/_coda/_base_coda.py +129 -145
  16. pertpy/tools/_coda/_sccoda.py +66 -69
  17. pertpy/tools/_coda/_tasccoda.py +71 -79
  18. pertpy/tools/_dialogue.py +48 -40
  19. pertpy/tools/_differential_gene_expression/_base.py +21 -31
  20. pertpy/tools/_differential_gene_expression/_checks.py +4 -6
  21. pertpy/tools/_differential_gene_expression/_dge_comparison.py +5 -6
  22. pertpy/tools/_differential_gene_expression/_edger.py +6 -10
  23. pertpy/tools/_differential_gene_expression/_pydeseq2.py +1 -1
  24. pertpy/tools/_differential_gene_expression/_simple_tests.py +3 -3
  25. pertpy/tools/_differential_gene_expression/_statsmodels.py +8 -5
  26. pertpy/tools/_distances/_distance_tests.py +1 -2
  27. pertpy/tools/_distances/_distances.py +31 -45
  28. pertpy/tools/_enrichment.py +7 -22
  29. pertpy/tools/_milo.py +19 -15
  30. pertpy/tools/_mixscape.py +73 -75
  31. pertpy/tools/_perturbation_space/_clustering.py +4 -4
  32. pertpy/tools/_perturbation_space/_comparison.py +4 -4
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +83 -32
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +10 -10
  35. pertpy/tools/_perturbation_space/_simple.py +12 -14
  36. pertpy/tools/_scgen/_scgen.py +16 -17
  37. pertpy/tools/_scgen/_scgenvae.py +2 -2
  38. pertpy/tools/_scgen/_utils.py +3 -1
  39. {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/METADATA +36 -20
  40. pertpy-0.11.0.dist-info/RECORD +58 -0
  41. {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/licenses/LICENSE +1 -0
  42. pertpy/tools/_kernel_pca.py +0 -50
  43. pertpy-0.10.0.dist-info/RECORD +0 -58
  44. {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/WHEEL +0 -0
@@ -33,7 +33,7 @@ if TYPE_CHECKING:
33
33
 
34
34
  import numpyro as npy
35
35
  import toytree as tt
36
- from ete3 import Tree
36
+ from ete4 import Tree
37
37
  from jax._src.typing import Array
38
38
  from matplotlib.axes import Axes
39
39
  from matplotlib.colors import Colormap
@@ -198,7 +198,7 @@ class CompositionalModel2(ABC):
198
198
  *args,
199
199
  **kwargs,
200
200
  ):
201
- """Background function that executes any numpyro MCMC algorithm and processes its results
201
+ """Background function that executes any numpyro MCMC algorithm and processes its results.
202
202
 
203
203
  Args:
204
204
  sample_adata: anndata object with cell counts as sample_adata.X and covariates saved in sample_adata.obs.
@@ -294,6 +294,8 @@ class CompositionalModel2(ABC):
294
294
  num_warmup: Number of burn-in (warmup) samples.
295
295
  rng_key: The rng state used.
296
296
  copy: Return a copy instead of writing to adata.
297
+ *args: Additional args passed to numpyro NUTS
298
+ **kwargs: Additional kwargs passed to numpyro NUTS
297
299
 
298
300
  Returns:
299
301
  Calls `self.__run_mcmc`
@@ -347,6 +349,8 @@ class CompositionalModel2(ABC):
347
349
  num_warmup: Number of burn-in (warmup) samples.
348
350
  rng_key: The rng state used. If None, a random state will be selected.
349
351
  copy: Return a copy instead of writing to adata.
352
+ *args: Additional args passed to numpyro HMC
353
+ **kwargs: Additional kwargs passed to numpyro HMC
350
354
 
351
355
  Examples:
352
356
  >>> import pertpy as pt
@@ -396,7 +400,8 @@ class CompositionalModel2(ABC):
396
400
  self, sample_adata: AnnData, est_fdr: float = 0.05, *args, **kwargs
397
401
  ) -> tuple[pd.DataFrame, pd.DataFrame] | tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
398
402
  """Generates summary dataframes for intercepts, effects and node-level effect (if using tree aggregation).
399
- This function builds on and supports all functionalities from ``az.summary``.
403
+
404
+ This function builds on and supports all functionalities from ``az.summary``.
400
405
 
401
406
  Args:
402
407
  sample_adata: Anndata object with cell counts as sample_adata.X and covariates saved in sample_adata.obs.
@@ -405,7 +410,7 @@ class CompositionalModel2(ABC):
405
410
  kwargs: Passed to ``az.summary``
406
411
 
407
412
  Returns:
408
- Tuple[pd.DataFrame, pd.DataFrame] or Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: Intercept, effect and node-level DataFrames
413
+ Tuple[:class:pandas.DataFrame, :class:pandas.DataFrame] or Tuple[:class:pandas.DataFrame, :class:pandas.DataFrame, :class:pandas.DataFrame]: Intercept, effect and node-level DataFrames
409
414
 
410
415
  intercept_df
411
416
  Summary of intercept parameters. Contains one row per cell type.
@@ -435,7 +440,7 @@ class CompositionalModel2(ABC):
435
440
  - Delta: Decision boundary value - threshold of practical significance
436
441
  - Is credible: Boolean indicator whether effect is credible
437
442
 
438
- Examples:
443
+ Examples:
439
444
  >>> import pertpy as pt
440
445
  >>> haber_cells = pt.dt.haber_2017_regions()
441
446
  >>> sccoda = pt.tl.Sccoda()
@@ -684,7 +689,7 @@ class CompositionalModel2(ABC):
684
689
 
685
690
  if fdr < alpha:
686
691
  # ceiling with 3 decimals precision
687
- c = np.floor(c * 10**3) / 10**3
692
+ c = np.floor(c * 10**3) / 10**3 # noqa: PLW2901
688
693
  return c, fdr
689
694
  return 1.0, 0
690
695
 
@@ -737,7 +742,8 @@ class CompositionalModel2(ABC):
737
742
  node_df: pd.DataFrame,
738
743
  ) -> pd.DataFrame:
739
744
  """Evaluation of MCMC results for node-level effect parameters. This function is only used within self.summary_prepare.
740
- This function determines whether node-level effects are credible or not
745
+
746
+ This function determines whether node-level effects are credible or not.
741
747
 
742
748
  Args:
743
749
  sample_adata: Anndata object with cell counts as sample_adata.X and covariates saved in sample_adata.obs.
@@ -932,15 +938,15 @@ class CompositionalModel2(ABC):
932
938
  )
933
939
  console.print(table)
934
940
 
935
- def get_intercept_df(self, data: AnnData | MuData, modality_key: str = "coda"):
936
- """Get intercept dataframe as printed in the extended summary
941
+ def get_intercept_df(self, data: AnnData | MuData, modality_key: str = "coda") -> pd.DataFrame:
942
+ """Get intercept dataframe as printed in the extended summary.
937
943
 
938
944
  Args:
939
945
  data: AnnData object or MuData object.
940
946
  modality_key: If data is a MuData object, specify which modality to use.
941
947
 
942
948
  Returns:
943
- pd.DataFrame: Intercept data frame.
949
+ Intercept data frame.
944
950
 
945
951
  Examples:
946
952
  >>> import pertpy as pt
@@ -963,15 +969,15 @@ class CompositionalModel2(ABC):
963
969
 
964
970
  return sample_adata.varm["intercept_df"]
965
971
 
966
- def get_effect_df(self, data: AnnData | MuData, modality_key: str = "coda"):
967
- """Get effect dataframe as printed in the extended summary
972
+ def get_effect_df(self, data: AnnData | MuData, modality_key: str = "coda") -> pd.DataFrame:
973
+ """Get effect dataframe as printed in the extended summary.
968
974
 
969
975
  Args:
970
976
  data: AnnData object or MuData object.
971
977
  modality_key: If data is a MuData object, specify which modality to use.
972
978
 
973
979
  Returns:
974
- pd.DataFrame: Effect data frame.
980
+ Effect data frame.
975
981
 
976
982
  Examples:
977
983
  >>> import pertpy as pt
@@ -1005,15 +1011,15 @@ class CompositionalModel2(ABC):
1005
1011
 
1006
1012
  return effect_df
1007
1013
 
1008
- def get_node_df(self, data: AnnData | MuData, modality_key: str = "coda"):
1009
- """Get node effect dataframe as printed in the extended summary of a tascCODA model
1014
+ def get_node_df(self, data: AnnData | MuData, modality_key: str = "coda") -> pd.DataFrame:
1015
+ """Get node effect dataframe as printed in the extended summary of a tascCODA model.
1010
1016
 
1011
1017
  Args:
1012
1018
  data: AnnData object or MuData object.
1013
1019
  modality_key: If data is a MuData object, specify which modality to use.
1014
1020
 
1015
1021
  Returns:
1016
- pd.DataFrame: Node effect data frame.
1022
+ Node effect data frame.
1017
1023
 
1018
1024
  Examples:
1019
1025
  >>> import pertpy as pt
@@ -1030,7 +1036,6 @@ class CompositionalModel2(ABC):
1030
1036
  >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
1031
1037
  >>> node_effects = tasccoda.get_node_df(mdata)
1032
1038
  """
1033
-
1034
1039
  if isinstance(data, MuData):
1035
1040
  try:
1036
1041
  sample_adata = data[modality_key]
@@ -1043,8 +1048,9 @@ class CompositionalModel2(ABC):
1043
1048
  return sample_adata.uns["scCODA_params"]["node_df"]
1044
1049
 
1045
1050
  def set_fdr(self, data: AnnData | MuData, est_fdr: float, modality_key: str = "coda", *args, **kwargs):
1046
- """Direct posterior probability approach to calculate credible effects while keeping the expected FDR at a certain level
1047
- Note: Does not work for spike-and-slab LASSO selection method
1051
+ """Direct posterior probability approach to calculate credible effects while keeping the expected FDR at a certain level.
1052
+
1053
+ Note: Does not work for spike-and-slab LASSO selection method.
1048
1054
 
1049
1055
  Args:
1050
1056
  data: AnnData object or MuData object.
@@ -1079,7 +1085,8 @@ class CompositionalModel2(ABC):
1079
1085
 
1080
1086
  def credible_effects(self, data: AnnData | MuData, modality_key: str = "coda", est_fdr: float = None) -> pd.Series:
1081
1087
  """Decides which effects of the scCODA model are credible based on an adjustable inclusion probability threshold.
1082
- Note: Parameter est_fdr has no effect for spike-and-slab LASSO selection method
1088
+
1089
+ Note: Parameter est_fdr has no effect for spike-and-slab LASSO selection method.
1083
1090
 
1084
1091
  Args:
1085
1092
  data: AnnData object or MuData object.
@@ -1087,7 +1094,7 @@ class CompositionalModel2(ABC):
1087
1094
  est_fdr: Estimated false discovery rate. Must be between 0 and 1.
1088
1095
 
1089
1096
  Returns:
1090
- pd.Series: Credible effect decision series which includes boolean values indicate whether effects are credible under inc_prob_threshold.
1097
+ Credible effect decision series which includes boolean values indicate whether effects are credible under inc_prob_threshold.
1091
1098
  """
1092
1099
  if isinstance(data, MuData):
1093
1100
  try:
@@ -1109,16 +1116,15 @@ class CompositionalModel2(ABC):
1109
1116
  else:
1110
1117
  _, eff_df = self.summary_prepare(sample_adata, est_fdr=est_fdr) # type: ignore
1111
1118
  # otherwise, get pre-calculated DataFrames. Effect DataFrame is stitched together from varm
1119
+ elif model_type == "tree_agg" and select_type == "sslasso":
1120
+ eff_df = sample_adata.uns["scCODA_params"]["node_df"]
1112
1121
  else:
1113
- if model_type == "tree_agg" and select_type == "sslasso":
1114
- eff_df = sample_adata.uns["scCODA_params"]["node_df"]
1115
- else:
1116
- covariates = sample_adata.uns["scCODA_params"]["covariate_names"]
1117
- effect_dfs = [sample_adata.varm[f"effect_df_{cov}"] for cov in covariates]
1118
- eff_df = pd.concat(effect_dfs)
1119
- eff_df.index = pd.MultiIndex.from_product(
1120
- (covariates, sample_adata.var.index.tolist()), names=["Covariate", "Cell Type"]
1121
- )
1122
+ covariates = sample_adata.uns["scCODA_params"]["covariate_names"]
1123
+ effect_dfs = [sample_adata.varm[f"effect_df_{cov}"] for cov in covariates]
1124
+ eff_df = pd.concat(effect_dfs)
1125
+ eff_df.index = pd.MultiIndex.from_product(
1126
+ (covariates, sample_adata.var.index.tolist()), names=["Covariate", "Cell Type"]
1127
+ )
1122
1128
 
1123
1129
  out = eff_df["Final Parameter"] != 0
1124
1130
  out.rename("credible change")
@@ -1188,7 +1194,7 @@ class CompositionalModel2(ABC):
1188
1194
  return ax
1189
1195
 
1190
1196
  @_doc_params(common_plot_args=doc_common_plot_args)
1191
- def plot_stacked_barplot( # pragma: no cover
1197
+ def plot_stacked_barplot( # pragma: no cover # noqa: D417
1192
1198
  self,
1193
1199
  data: AnnData | MuData,
1194
1200
  feature_name: str,
@@ -1215,7 +1221,7 @@ class CompositionalModel2(ABC):
1215
1221
  {common_plot_args}
1216
1222
 
1217
1223
  Returns:
1218
- If `return_fig` is `True`, returns the figure, otherwise `None`.
1224
+ If `return_fig` is `True`, returns the Figure, otherwise `None`.
1219
1225
 
1220
1226
  Examples:
1221
1227
  >>> import pertpy as pt
@@ -1230,8 +1236,6 @@ class CompositionalModel2(ABC):
1230
1236
  """
1231
1237
  if isinstance(data, MuData):
1232
1238
  data = data[modality_key]
1233
- if isinstance(data, AnnData):
1234
- data = data
1235
1239
 
1236
1240
  ct_names = data.var.index
1237
1241
 
@@ -1283,7 +1287,7 @@ class CompositionalModel2(ABC):
1283
1287
  return None
1284
1288
 
1285
1289
  @_doc_params(common_plot_args=doc_common_plot_args)
1286
- def plot_effects_barplot( # pragma: no cover
1290
+ def plot_effects_barplot( # pragma: no cover # noqa: D417
1287
1291
  self,
1288
1292
  data: AnnData | MuData,
1289
1293
  *,
@@ -1340,8 +1344,7 @@ class CompositionalModel2(ABC):
1340
1344
  args_barplot = {}
1341
1345
  if isinstance(data, MuData):
1342
1346
  data = data[modality_key]
1343
- if isinstance(data, AnnData):
1344
- data = data
1347
+
1345
1348
  # Get covariate names from adata, partition into those with nonzero effects for min. one cell type/no cell types
1346
1349
  covariate_names = data.uns["scCODA_params"]["covariate_names"]
1347
1350
  if covariates is not None:
@@ -1372,18 +1375,16 @@ class CompositionalModel2(ABC):
1372
1375
 
1373
1376
  plot_df = plot_df.reset_index()
1374
1377
 
1375
- if len(covariate_names_zero) != 0:
1376
- if plot_facets:
1377
- if plot_zero_covariate and not plot_zero_cell_type:
1378
- for covariate_name_zero in covariate_names_zero:
1379
- new_row = {
1380
- "Covariate": covariate_name_zero,
1381
- "Cell Type": "zero",
1382
- "value": 0,
1383
- }
1384
- plot_df = pd.concat([plot_df, pd.DataFrame([new_row])], ignore_index=True)
1385
- plot_df["covariate_"] = pd.Categorical(plot_df["Covariate"], covariate_names)
1386
- plot_df = plot_df.sort_values(["covariate_"])
1378
+ if len(covariate_names_zero) != 0 and plot_facets and plot_zero_covariate and not plot_zero_cell_type:
1379
+ for covariate_name_zero in covariate_names_zero:
1380
+ new_row = {
1381
+ "Covariate": covariate_name_zero,
1382
+ "Cell Type": "zero",
1383
+ "value": 0,
1384
+ }
1385
+ plot_df = pd.concat([plot_df, pd.DataFrame([new_row])], ignore_index=True)
1386
+ plot_df["covariate_"] = pd.Categorical(plot_df["Covariate"], covariate_names)
1387
+ plot_df = plot_df.sort_values(["covariate_"])
1387
1388
  if not plot_zero_cell_type:
1388
1389
  cell_type_names_zero = [
1389
1390
  name
@@ -1427,9 +1428,8 @@ class CompositionalModel2(ABC):
1427
1428
  ax.set_title(covariate_names[i])
1428
1429
  if len(ax.get_xticklabels()) < 5:
1429
1430
  ax.set_aspect(10 / len(ax.get_xticklabels()))
1430
- if len(ax.get_xticklabels()) == 1:
1431
- if ax.get_xticklabels()[0]._text == "zero":
1432
- ax.set_xticks([])
1431
+ if len(ax.get_xticklabels()) == 1 and ax.get_xticklabels()[0]._text == "zero":
1432
+ ax.set_xticks([])
1433
1433
 
1434
1434
  # If not plot as facets, call barplot to plot cell types on the x-axis.
1435
1435
  else:
@@ -1460,6 +1460,7 @@ class CompositionalModel2(ABC):
1460
1460
  ax=ax,
1461
1461
  )
1462
1462
  cell_types = pd.unique(plot_df["Cell Type"])
1463
+ ax.set_xticks(cell_types)
1463
1464
  ax.set_xticklabels(cell_types, rotation=90)
1464
1465
 
1465
1466
  if return_fig and plot_facets:
@@ -1470,7 +1471,7 @@ class CompositionalModel2(ABC):
1470
1471
  return None
1471
1472
 
1472
1473
  @_doc_params(common_plot_args=doc_common_plot_args)
1473
- def plot_boxplots( # pragma: no cover
1474
+ def plot_boxplots( # pragma: no cover # noqa: D417
1474
1475
  self,
1475
1476
  data: AnnData | MuData,
1476
1477
  feature_name: str,
@@ -1532,8 +1533,7 @@ class CompositionalModel2(ABC):
1532
1533
  args_swarmplot = {}
1533
1534
  if isinstance(data, MuData):
1534
1535
  data = data[modality_key]
1535
- if isinstance(data, AnnData):
1536
- data = data
1536
+
1537
1537
  # y scale transformations
1538
1538
  if y_scale == "relative":
1539
1539
  sample_sums = np.sum(data.X, axis=1, keepdims=True)
@@ -1607,10 +1607,7 @@ class CompositionalModel2(ABC):
1607
1607
  )
1608
1608
 
1609
1609
  if add_dots:
1610
- if "hue" in args_swarmplot:
1611
- hue = args_swarmplot.pop("hue")
1612
- else:
1613
- hue = None
1610
+ hue = args_swarmplot.pop("hue") if "hue" in args_swarmplot else None
1614
1611
 
1615
1612
  if hue is None:
1616
1613
  g.map(
@@ -1675,6 +1672,7 @@ class CompositionalModel2(ABC):
1675
1672
  )
1676
1673
 
1677
1674
  cell_types = pd.unique(plot_df["Cell type"])
1675
+ ax.set_xticks(cell_types)
1678
1676
  ax.set_xticklabels(cell_types, rotation=90)
1679
1677
 
1680
1678
  if show_legend:
@@ -1702,7 +1700,7 @@ class CompositionalModel2(ABC):
1702
1700
  return None
1703
1701
 
1704
1702
  @_doc_params(common_plot_args=doc_common_plot_args)
1705
- def plot_rel_abundance_dispersion_plot( # pragma: no cover
1703
+ def plot_rel_abundance_dispersion_plot( # pragma: no cover # noqa: D417
1706
1704
  self,
1707
1705
  data: AnnData | MuData,
1708
1706
  *,
@@ -1750,8 +1748,7 @@ class CompositionalModel2(ABC):
1750
1748
  """
1751
1749
  if isinstance(data, MuData):
1752
1750
  data = data[modality_key]
1753
- if isinstance(data, AnnData):
1754
- data = data
1751
+
1755
1752
  if ax is None:
1756
1753
  _, ax = plt.subplots(figsize=figsize, dpi=dpi)
1757
1754
 
@@ -1823,13 +1820,13 @@ class CompositionalModel2(ABC):
1823
1820
  return None
1824
1821
 
1825
1822
  @_doc_params(common_plot_args=doc_common_plot_args)
1826
- def plot_draw_tree( # pragma: no cover
1823
+ def plot_draw_tree( # pragma: no cover # noqa: D417
1827
1824
  self,
1828
1825
  data: AnnData | MuData,
1829
1826
  *,
1830
1827
  modality_key: str = "coda",
1831
- tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors
1832
- tight_text: bool | None = False,
1828
+ tree: str = "tree", # Also type ete4.Tree. Omitted due to import errors
1829
+ tight_text: bool = False,
1833
1830
  show_scale: bool | None = False,
1834
1831
  units: Literal["px", "mm", "in"] | None = "px",
1835
1832
  figsize: tuple[float, float] | None = (None, None),
@@ -1837,12 +1834,12 @@ class CompositionalModel2(ABC):
1837
1834
  save: str | bool = False,
1838
1835
  return_fig: bool = False,
1839
1836
  ) -> Tree | None:
1840
- """Plot a tree using input ete3 tree object.
1837
+ """Plot a tree using input ete4 tree object.
1841
1838
 
1842
1839
  Args:
1843
1840
  data: AnnData object or MuData object.
1844
1841
  modality_key: If data is a MuData object, specify which modality to use.
1845
- tree: A ete3 tree object or a str to indicate the tree stored in `.uns`.
1842
+ tree: A ete4 tree object or a str to indicate the tree stored in `.uns`.
1846
1843
  tight_text: When False, boundaries of the text are approximated according to general font metrics,
1847
1844
  producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
1848
1845
  show_scale: Include the scale legend in the tree image or not.
@@ -1853,7 +1850,7 @@ class CompositionalModel2(ABC):
1853
1850
  {common_plot_args}
1854
1851
 
1855
1852
  Returns:
1856
- Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`) or plot the tree inline (`show = False`)
1853
+ Depending on `save`, returns :class:`ete4.core.tree.Tree` and :class:`ete4.treeview.TreeStyle` (`save = 'output.png'`) or plot the tree inline (`save = False`)
1857
1854
 
1858
1855
  Examples:
1859
1856
  >>> import pertpy as pt
@@ -1874,7 +1871,8 @@ class CompositionalModel2(ABC):
1874
1871
  .. image:: /_static/docstring_previews/tasccoda_draw_tree.png
1875
1872
  """
1876
1873
  try:
1877
- from ete3 import CircleFace, NodeStyle, TextFace, Tree, TreeStyle, faces
1874
+ from ete4 import Tree
1875
+ from ete4.treeview import CircleFace, NodeStyle, TextFace, TreeStyle, faces
1878
1876
  except ImportError:
1879
1877
  raise ImportError(
1880
1878
  "To use tasccoda please install additional dependencies with `pip install pertpy[coda]`"
@@ -1882,8 +1880,6 @@ class CompositionalModel2(ABC):
1882
1880
 
1883
1881
  if isinstance(data, MuData):
1884
1882
  data = data[modality_key]
1885
- if isinstance(data, AnnData):
1886
- data = data
1887
1883
  if isinstance(tree, str):
1888
1884
  tree = data.uns[tree]
1889
1885
 
@@ -1896,7 +1892,7 @@ class CompositionalModel2(ABC):
1896
1892
  tree_style.layout_fn = my_layout
1897
1893
  tree_style.show_scale = show_scale
1898
1894
 
1899
- if save is not None:
1895
+ if save:
1900
1896
  tree.render(save, tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore
1901
1897
  if return_fig:
1902
1898
  return tree, tree_style
@@ -1904,13 +1900,13 @@ class CompositionalModel2(ABC):
1904
1900
  return None
1905
1901
 
1906
1902
  @_doc_params(common_plot_args=doc_common_plot_args)
1907
- def plot_draw_effects( # pragma: no cover
1903
+ def plot_draw_effects( # pragma: no cover # noqa: D417
1908
1904
  self,
1909
1905
  data: AnnData | MuData,
1910
1906
  covariate: str,
1911
1907
  *,
1912
1908
  modality_key: str = "coda",
1913
- tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors
1909
+ tree: str = "tree", # Also type ete4.Tree. Omitted due to import errors
1914
1910
  show_legend: bool | None = None,
1915
1911
  show_leaf_effects: bool | None = False,
1916
1912
  tight_text: bool | None = False,
@@ -1927,7 +1923,7 @@ class CompositionalModel2(ABC):
1927
1923
  data: AnnData object or MuData object.
1928
1924
  covariate: The covariate, whose effects should be plotted.
1929
1925
  modality_key: If data is a MuData object, specify which modality to use.
1930
- tree: A ete3 tree object or a str to indicate the tree stored in `.uns`.
1926
+ tree: A ete4 tree object or a str to indicate the tree stored in `.uns`.
1931
1927
  show_legend: If show legend of nodes significant effects or not.
1932
1928
  Defaults to False if show_leaf_effects is True.
1933
1929
  show_leaf_effects: If True, plot bar plots which indicate leave-level significant effects.
@@ -1941,8 +1937,8 @@ class CompositionalModel2(ABC):
1941
1937
  {common_plot_args}
1942
1938
 
1943
1939
  Returns:
1944
- Returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`return_fig = False`)
1945
- or plot the tree inline (`show = True`)
1940
+ Depending on `save`, returns :class:`ete4.core.tree.Tree` and :class:`ete4.treeview.TreeStyle` (`save = 'output.png'`)
1941
+ or plot the tree inline (`save = False`).
1946
1942
 
1947
1943
  Examples:
1948
1944
  >>> import pertpy as pt
@@ -1963,7 +1959,8 @@ class CompositionalModel2(ABC):
1963
1959
  .. image:: /_static/docstring_previews/tasccoda_draw_effects.png
1964
1960
  """
1965
1961
  try:
1966
- from ete3 import CircleFace, NodeStyle, TextFace, Tree, TreeStyle, faces
1962
+ from ete4 import Tree
1963
+ from ete4.treeview import CircleFace, NodeStyle, TextFace, TreeStyle, faces
1967
1964
  except ImportError:
1968
1965
  raise ImportError(
1969
1966
  "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
@@ -1971,8 +1968,6 @@ class CompositionalModel2(ABC):
1971
1968
 
1972
1969
  if isinstance(data, MuData):
1973
1970
  data = data[modality_key]
1974
- if isinstance(data, AnnData):
1975
- data = data
1976
1971
  if show_legend is None:
1977
1972
  show_legend = not show_leaf_effects
1978
1973
  elif show_legend:
@@ -2003,18 +1998,18 @@ class CompositionalModel2(ABC):
2003
1998
  n.set_style(nstyle)
2004
1999
  if n.name in node_effs.index:
2005
2000
  e = node_effs.loc[n.name, "Final Parameter"]
2006
- n.add_feature("node_effect", e)
2001
+ n.add_prop("node_effect", e)
2007
2002
  else:
2008
- n.add_feature("node_effect", 0)
2003
+ n.add_prop("node_effect", 0)
2009
2004
  if n.name in leaf_effs.index:
2010
2005
  e = leaf_effs.loc[n.name, "Effect"]
2011
- n.add_feature("leaf_effect", e)
2006
+ n.add_prop("leaf_effect", e)
2012
2007
  else:
2013
- n.add_feature("leaf_effect", 0)
2008
+ n.add_prop("leaf_effect", 0)
2014
2009
 
2015
2010
  # Scale effect values to get nice node sizes
2016
- eff_max = np.max([np.abs(n.node_effect) for n in tree2.traverse()])
2017
- leaf_eff_max = np.max([np.abs(n.leaf_effect) for n in tree2.traverse()])
2011
+ eff_max = np.max([np.abs(n.props.get("node_effect")) for n in tree2.traverse()])
2012
+ leaf_eff_max = np.max([np.abs(n.props.get("leaf_effect")) for n in tree2.traverse()])
2018
2013
 
2019
2014
  def my_layout(node):
2020
2015
  text_face = TextFace(node.name, tight_text=tight_text)
@@ -2022,10 +2017,10 @@ class CompositionalModel2(ABC):
2022
2017
  faces.add_face_to_node(text_face, node, column=0, aligned=True)
2023
2018
 
2024
2019
  # if node.is_leaf():
2025
- size = (np.abs(node.node_effect) * 10 / eff_max) if node.node_effect != 0 else 0
2026
- if np.sign(node.node_effect) == 1:
2020
+ size = (np.abs(node.props.get("node_effect")) * 10 / eff_max) if node.props.get("node_effect") != 0 else 0
2021
+ if np.sign(node.props.get("node_effect")) == 1:
2027
2022
  color = "blue"
2028
- elif np.sign(node.node_effect) == -1:
2023
+ elif np.sign(node.props.get("node_effect")) == -1:
2029
2024
  color = "red"
2030
2025
  else:
2031
2026
  color = "cyan"
@@ -2061,13 +2056,13 @@ class CompositionalModel2(ABC):
2061
2056
  tree_style.legend.add_face(TextFace(f" {eff_max * i / 4:.2f}"), column=1)
2062
2057
 
2063
2058
  if show_leaf_effects:
2064
- leaf_name = [node.name for node in tree2.traverse("postorder") if node.is_leaf()]
2059
+ leaf_name = [node.name for node in tree2.traverse("postorder") if node.is_leaf]
2065
2060
  leaf_effs = leaf_effs.loc[leaf_name].reset_index()
2066
2061
  palette = ["blue" if Effect > 0 else "red" for Effect in leaf_effs["Effect"].tolist()]
2067
2062
 
2068
2063
  dir_path = Path.cwd()
2069
2064
  dir_path = Path(dir_path / "tree_effect.png")
2070
- tree2.render(dir_path, tree_style=tree_style, units="in")
2065
+ tree2.render(dir_path.as_posix(), tree_style=tree_style, units="in")
2071
2066
  _, ax = plt.subplots(1, 2, figsize=(10, 10))
2072
2067
  sns.barplot(data=leaf_effs, x="Effect", y="Cell Type", palette=palette, ax=ax[1])
2073
2068
  img = mpimg.imread(dir_path)
@@ -2098,7 +2093,7 @@ class CompositionalModel2(ABC):
2098
2093
  return None
2099
2094
 
2100
2095
  @_doc_params(common_plot_args=doc_common_plot_args)
2101
- def plot_effects_umap( # pragma: no cover
2096
+ def plot_effects_umap( # pragma: no cover # noqa: D417
2102
2097
  self,
2103
2098
  mdata: MuData,
2104
2099
  effect_name: str | list | None,
@@ -2211,7 +2206,7 @@ class CompositionalModel2(ABC):
2211
2206
  def get_a(
2212
2207
  tree: tt.core.ToyTree,
2213
2208
  ) -> tuple[np.ndarray, int]:
2214
- """Calculate ancestor matrix from a toytree tree
2209
+ """Calculate ancestor matrix from a toytree tree.
2215
2210
 
2216
2211
  Args:
2217
2212
  tree: A toytree tree object.
@@ -2264,16 +2259,14 @@ def collapse_singularities(tree: tt.core.ToyTree) -> tt.core.ToyTree:
2264
2259
  A_T = A.T
2265
2260
  unq, count = np.unique(A_T, axis=0, return_counts=True)
2266
2261
 
2267
- repeated_idx = []
2268
- for repeated_group in unq[count > 1]:
2269
- repeated_idx.append(np.argwhere(np.all(A_T == repeated_group, axis=1)).ravel())
2262
+ repeated_idx = [np.argwhere(np.all(repeated_group == A_T, axis=1)).ravel() for repeated_group in unq[count > 1]]
2270
2263
 
2271
2264
  nodes_to_delete = [i for idx in repeated_idx for i in idx[1:]]
2272
2265
 
2273
2266
  # _coords.update() scrambles the idx of leaves. Therefore, keep track of it here
2274
2267
  tree_new = tree.copy()
2275
2268
  for node in tree_new.treenode.traverse():
2276
- node.add_feature("idx_orig", node.idx)
2269
+ node.add_prop("idx_orig", node.idx)
2277
2270
 
2278
2271
  for n in nodes_to_delete:
2279
2272
  node = tree_new.idx_dict[n]
@@ -2289,21 +2282,16 @@ def collapse_singularities(tree: tt.core.ToyTree) -> tt.core.ToyTree:
2289
2282
  return tree_new
2290
2283
 
2291
2284
 
2292
- def traverse(df_, a, i, innerl):
2293
- """
2294
- Helper function for df2newick
2295
- Adapted from https://stackoverflow.com/questions/15343338/how-to-convert-a-data-frame-to-tree-structure-object-such-as-dendrogram
2285
+ def traverse(df_: pd.DataFrame, a: str, i: int, innerl: bool) -> str:
2286
+ """Helper function for df2newick.
2287
+
2288
+ Adapted from https://stackoverflow.com/questions/15343338/how-to-convert-a-data-frame-to-tree-structure-object-such-as-dendrogram.
2296
2289
  """
2297
2290
  if i + 1 < df_.shape[1]:
2298
2291
  a_inner = pd.unique(df_.loc[np.where(df_.iloc[:, i] == a)].iloc[:, i + 1])
2299
2292
 
2300
- desc = []
2301
- for b in a_inner:
2302
- desc.append(traverse(df_, b, i + 1, innerl))
2303
- if innerl:
2304
- il = a
2305
- else:
2306
- il = ""
2293
+ desc = [traverse(df_, b, i + 1, innerl) for b in a_inner]
2294
+ il = a if innerl else ""
2307
2295
  out = f"({','.join(desc)}){il}"
2308
2296
  else:
2309
2297
  out = a
@@ -2327,9 +2315,7 @@ def df2newick(df: pd.DataFrame, levels: list[str], inner_label: bool = True) ->
2327
2315
  df_tax = df.loc[:, [x for x in levels if x in df.columns]]
2328
2316
 
2329
2317
  alevel = pd.unique(df_tax.iloc[:, 0])
2330
- strs = []
2331
- for a in alevel:
2332
- strs.append(traverse(df_tax, a, 0, inner_label))
2318
+ strs = [traverse(df_tax, a, 0, inner_label) for a in alevel]
2333
2319
 
2334
2320
  newick = f"({','.join(strs)});"
2335
2321
  return newick
@@ -2340,10 +2326,10 @@ def get_a_2(
2340
2326
  leaf_order: list[str] = None,
2341
2327
  node_order: list[str] = None,
2342
2328
  ) -> tuple[np.ndarray, int]:
2343
- """Calculate ancestor matrix from a ete3 tree.
2329
+ """Calculate ancestor matrix from a ete4 tree.
2344
2330
 
2345
2331
  Args:
2346
- tree: A ete3 tree object.
2332
+ tree: A ete4 tree object.
2347
2333
  leaf_order: List of leaf names how they should appear as the rows of the ancestor matrix.
2348
2334
  If None, the ordering will be as in `tree.iter_leaves()`
2349
2335
  node_order: List of node names how they should appear as the columns of the ancestor matrix
@@ -2358,29 +2344,29 @@ def get_a_2(
2358
2344
  number of nodes in the tree, excluding the root node
2359
2345
  """
2360
2346
  try:
2361
- import ete3 as ete
2347
+ import ete4 as ete
2362
2348
  except ImportError:
2363
2349
  raise ImportError(
2364
2350
  "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
2365
2351
  ) from None
2366
2352
 
2367
- n_tips = len(tree.get_leaves())
2368
- n_nodes = len(tree.get_descendants())
2353
+ n_tips = len(list(tree.leaves()))
2354
+ n_nodes = len(list(tree.descendants()))
2369
2355
 
2370
- node_names = [n.name for n in tree.iter_descendants()]
2356
+ node_names = [n.name for n in tree.descendants()]
2371
2357
  duplicates = [x for x in node_names if node_names.count(x) > 1]
2372
2358
  if len(duplicates) > 0:
2373
2359
  raise ValueError(f"Tree nodes have duplicate names: {duplicates}. Make sure that node names are unique!")
2374
2360
 
2375
2361
  # Initialize ancestor matrix
2376
2362
  A_ = pd.DataFrame(np.zeros((n_tips, n_nodes)))
2377
- A_.index = tree.get_leaf_names()
2378
- A_.columns = [n.name for n in tree.iter_descendants()]
2363
+ A_.index = tree.leaf_names()
2364
+ A_.columns = [n.name for n in tree.descendants()]
2379
2365
 
2380
2366
  # Fill in 1's for all connections
2381
- for node in tree.iter_descendants():
2382
- for leaf in tree.get_leaves():
2383
- if leaf in node.get_leaves():
2367
+ for node in tree.descendants():
2368
+ for leaf in tree.leaves():
2369
+ if leaf in node.leaves():
2384
2370
  A_.loc[leaf.name, node.name] = 1
2385
2371
 
2386
2372
  # Order rows and columns
@@ -2394,15 +2380,15 @@ def get_a_2(
2394
2380
 
2395
2381
 
2396
2382
  def collapse_singularities_2(tree: Tree) -> Tree:
2397
- """Collapses (deletes) nodes in a ete3 tree that are singularities (have only one child).
2383
+ """Collapses (deletes) nodes in a ete4 tree that are singularities (have only one child).
2398
2384
 
2399
2385
  Args:
2400
- tree: A ete3 tree object
2386
+ tree: A ete4 tree object
2401
2387
 
2402
2388
  Returns:
2403
- A ete3 tree without singularities.
2389
+ A ete4 tree without singularities.
2404
2390
  """
2405
- for node in tree.iter_descendants():
2391
+ for node in tree.descendants():
2406
2392
  if len(node.get_children()) == 1:
2407
2393
  node.delete()
2408
2394
 
@@ -2427,13 +2413,10 @@ def linkage_to_newick(
2427
2413
  tree = sp_hierarchy.to_tree(Z, False)
2428
2414
 
2429
2415
  def build_newick(node, newick, parentdist, leaf_names):
2430
- if node.is_leaf():
2416
+ if node.is_leaf:
2431
2417
  return f"{leaf_names[node.id]}:{(parentdist - node.dist) / 2}{newick}"
2432
2418
  else:
2433
- if len(newick) > 0:
2434
- newick = f"):{(parentdist - node.dist) / 2}{newick}"
2435
- else:
2436
- newick = ");"
2419
+ newick = f"):{(parentdist - node.dist) / 2}{newick}" if len(newick) > 0 else ");"
2437
2420
  newick = build_newick(node.get_left(), newick, node.dist, leaf_names)
2438
2421
  newick = build_newick(node.get_right(), f",{newick}", node.dist, leaf_names)
2439
2422
  newick = f"({newick}"
@@ -2478,10 +2461,10 @@ def import_tree(
2478
2461
 
2479
2462
  See `key_added` parameter description for the storage path of tree.
2480
2463
 
2481
- tree: A ete3 tree object.
2464
+ tree: A ete4 tree object.
2482
2465
  """
2483
2466
  try:
2484
- import ete3 as ete
2467
+ import ete4 as ete
2485
2468
  except ImportError:
2486
2469
  raise ImportError(
2487
2470
  "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
@@ -2506,32 +2489,33 @@ def import_tree(
2506
2489
  data_1.uns["dendrogram_cell_label"]["linkage"],
2507
2490
  labels=data_1.uns["dendrogram_cell_label"]["categories_ordered"],
2508
2491
  )
2509
- tree = ete.Tree(newick, format=1)
2492
+ tree = ete.Tree(newick, parser=1)
2510
2493
  node_id = 0
2511
- for n in tree.iter_descendants():
2512
- if not n.is_leaf():
2494
+ for n in tree.descendants():
2495
+ if not n.is_leaf:
2513
2496
  n.name = str(node_id)
2514
2497
  node_id += 1
2515
2498
  elif levels_orig is not None:
2516
2499
  newick = df2newick(data_1.obs.reset_index(), levels=levels_orig)
2517
- tree = ete.Tree(newick, format=8)
2500
+ tree = ete.Tree(newick, parser=8)
2501
+
2518
2502
  if add_level_name:
2519
- for n in tree.iter_descendants():
2520
- if not n.is_leaf():
2521
- dist = n.get_distance(n, tree)
2503
+ for n in tree.descendants():
2504
+ if not n.is_leaf:
2505
+ dist = n.get_distance(n, tree, topological=True)
2522
2506
  n.name = f"{levels_orig[int(dist) - 1]}_{n.name}"
2523
2507
  elif levels_agg is not None:
2524
2508
  newick = df2newick(data_2.var.reset_index(), levels=levels_agg)
2525
- tree = ete.Tree(newick, format=8)
2509
+ tree = ete.Tree(newick, parser=8)
2526
2510
  if add_level_name:
2527
- for n in tree.iter_descendants():
2528
- if not n.is_leaf():
2529
- dist = n.get_distance(n, tree)
2511
+ for n in tree.descendants():
2512
+ if not n.is_leaf:
2513
+ dist = n.get_distance(n, tree, topological=True)
2530
2514
  n.name = f"{levels_agg[int(dist) - 1]}_{n.name}"
2531
2515
  else:
2532
2516
  raise ValueError("Either dendrogram_key, levels_orig or levels_agg must be specified!")
2533
2517
 
2534
- node_names = [n.name for n in tree.iter_descendants()]
2518
+ node_names = [n.name for n in tree.descendants()]
2535
2519
  duplicates = {x for x in node_names if node_names.count(x) > 1}
2536
2520
  if len(duplicates) > 0:
2537
2521
  raise ValueError(f"Tree nodes have duplicate names: {duplicates}. Make sure that node names are unique!")