phylogenie 2.1.30__py3-none-any.whl → 3.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- phylogenie/__init__.py +35 -28
- phylogenie/draw.py +22 -29
- phylogenie/generators/alisim.py +7 -8
- phylogenie/generators/configs.py +23 -1
- phylogenie/generators/dataset.py +1 -1
- phylogenie/generators/factories.py +10 -13
- phylogenie/generators/trees.py +19 -10
- phylogenie/io/__init__.py +1 -3
- phylogenie/mixins.py +41 -0
- phylogenie/treesimulator/__init__.py +29 -3
- phylogenie/treesimulator/events/base.py +0 -3
- phylogenie/treesimulator/events/contact_tracing.py +8 -9
- phylogenie/treesimulator/events/mutations.py +7 -8
- phylogenie/treesimulator/features.py +3 -3
- phylogenie/treesimulator/gillespie.py +21 -37
- phylogenie/treesimulator/io/__init__.py +4 -0
- phylogenie/{io → treesimulator/io}/newick.py +3 -3
- phylogenie/{io → treesimulator/io}/nexus.py +2 -2
- phylogenie/treesimulator/model.py +7 -10
- phylogenie/{tree.py → treesimulator/tree.py} +110 -84
- phylogenie/treesimulator/utils.py +108 -0
- {phylogenie-2.1.30.dist-info → phylogenie-3.1.1.dist-info}/METADATA +1 -1
- phylogenie-3.1.1.dist-info/RECORD +40 -0
- phylogenie/models.py +0 -17
- phylogenie/utils.py +0 -176
- phylogenie-2.1.30.dist-info/RECORD +0 -39
- {phylogenie-2.1.30.dist-info → phylogenie-3.1.1.dist-info}/LICENSE.txt +0 -0
- {phylogenie-2.1.30.dist-info → phylogenie-3.1.1.dist-info}/WHEEL +0 -0
- {phylogenie-2.1.30.dist-info → phylogenie-3.1.1.dist-info}/entry_points.txt +0 -0
|
@@ -1,10 +1,9 @@
|
|
|
1
1
|
import re
|
|
2
2
|
from copy import deepcopy
|
|
3
|
-
from typing import Any
|
|
3
|
+
from typing import Any, Callable
|
|
4
4
|
|
|
5
5
|
from numpy.random import Generator
|
|
6
6
|
|
|
7
|
-
from phylogenie.models import Distribution
|
|
8
7
|
from phylogenie.skyline import SkylineParameterLike
|
|
9
8
|
from phylogenie.treesimulator.events.base import Event, EventType
|
|
10
9
|
from phylogenie.treesimulator.events.contact_tracing import (
|
|
@@ -42,7 +41,7 @@ class Mutation(Event):
|
|
|
42
41
|
self,
|
|
43
42
|
state: str,
|
|
44
43
|
rate: SkylineParameterLike,
|
|
45
|
-
rate_scalers: dict[EventType,
|
|
44
|
+
rate_scalers: dict[EventType, Callable[[], float]],
|
|
46
45
|
rates_to_log: list[EventType] | None = None,
|
|
47
46
|
):
|
|
48
47
|
super().__init__(state, rate)
|
|
@@ -52,16 +51,16 @@ class Mutation(Event):
|
|
|
52
51
|
def apply(
|
|
53
52
|
self, model: Model, events: list[Event], time: float, rng: Generator
|
|
54
53
|
) -> dict[str, Any]:
|
|
55
|
-
if NEXT_MUTATION_ID not in model.
|
|
56
|
-
model
|
|
57
|
-
model
|
|
58
|
-
mutation_id = model
|
|
54
|
+
if NEXT_MUTATION_ID not in model.metadata:
|
|
55
|
+
model[NEXT_MUTATION_ID] = 0
|
|
56
|
+
model[NEXT_MUTATION_ID] += 1
|
|
57
|
+
mutation_id = model[NEXT_MUTATION_ID]
|
|
59
58
|
|
|
60
59
|
individual = self.draw_individual(model, rng)
|
|
61
60
|
model.migrate(individual, _get_mutated_state(mutation_id, self.state), time)
|
|
62
61
|
|
|
63
62
|
rate_scalers: dict[EventType, float] = {
|
|
64
|
-
target_type:
|
|
63
|
+
target_type: rate_scaler()
|
|
65
64
|
for target_type, rate_scaler in self.rate_scalers.items()
|
|
66
65
|
}
|
|
67
66
|
|
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
from collections.abc import Iterable
|
|
2
2
|
from enum import Enum
|
|
3
3
|
|
|
4
|
-
from phylogenie.tree import Tree
|
|
5
4
|
from phylogenie.treesimulator.events.mutations import get_mutation_id
|
|
6
5
|
from phylogenie.treesimulator.model import get_node_state
|
|
7
|
-
from phylogenie.
|
|
6
|
+
from phylogenie.treesimulator.tree import Tree
|
|
7
|
+
from phylogenie.treesimulator.utils import (
|
|
8
8
|
get_node_depth_levels,
|
|
9
9
|
get_node_depths,
|
|
10
10
|
get_node_height_levels,
|
|
@@ -46,4 +46,4 @@ def set_features(tree: Tree, features: Iterable[Feature]) -> None:
|
|
|
46
46
|
for feature in features:
|
|
47
47
|
feature_maps = FEATURES_EXTRACTORS[feature](tree)
|
|
48
48
|
for node in tree:
|
|
49
|
-
node
|
|
49
|
+
node[feature.value] = feature_maps[node]
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import time
|
|
2
2
|
from collections.abc import Iterable, Sequence
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import Any
|
|
4
|
+
from typing import Any, Callable
|
|
5
5
|
|
|
6
6
|
import joblib
|
|
7
7
|
import numpy as np
|
|
@@ -9,26 +9,28 @@ import pandas as pd
|
|
|
9
9
|
from numpy.random import default_rng
|
|
10
10
|
from tqdm import tqdm
|
|
11
11
|
|
|
12
|
-
from phylogenie.io import dump_newick
|
|
13
|
-
from phylogenie.tree import Tree
|
|
14
12
|
from phylogenie.treesimulator.events import Event
|
|
15
13
|
from phylogenie.treesimulator.features import Feature, set_features
|
|
14
|
+
from phylogenie.treesimulator.io import dump_newick
|
|
16
15
|
from phylogenie.treesimulator.model import Model
|
|
16
|
+
from phylogenie.treesimulator.tree import Tree
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
def simulate_tree(
|
|
20
20
|
events: Sequence[Event],
|
|
21
|
-
|
|
22
|
-
max_tips: int | None = None,
|
|
21
|
+
n_tips: int | None = None,
|
|
23
22
|
max_time: float = np.inf,
|
|
24
23
|
init_state: str | None = None,
|
|
25
24
|
sampling_probability_at_present: float = 0.0,
|
|
26
25
|
seed: int | None = None,
|
|
27
26
|
timeout: float = np.inf,
|
|
27
|
+
acceptance_criterion: Callable[[Tree], bool] | None = None,
|
|
28
28
|
) -> tuple[Tree, dict[str, Any]]:
|
|
29
|
-
if max_time
|
|
29
|
+
if (max_time != np.inf) == (n_tips is not None):
|
|
30
|
+
raise ValueError("Exactly one of max_time or n_tips must be specified.")
|
|
31
|
+
if sampling_probability_at_present and max_time == np.inf:
|
|
30
32
|
raise ValueError(
|
|
31
|
-
"sampling_probability_at_present
|
|
33
|
+
"sampling_probability_at_present can only be used with max_time."
|
|
32
34
|
)
|
|
33
35
|
|
|
34
36
|
states = {e.state for e in events if e.state}
|
|
@@ -51,32 +53,12 @@ def simulate_tree(
|
|
|
51
53
|
change_times = sorted(set(t for e in events for t in e.rate.change_times))
|
|
52
54
|
next_change_time = change_times.pop(0) if change_times else np.inf
|
|
53
55
|
|
|
54
|
-
|
|
55
|
-
if max_tips is None:
|
|
56
|
-
raise ValueError("Either max_time or max_tips must be specified.")
|
|
57
|
-
target_n_tips = rng.integers(min_tips, max_tips + 1)
|
|
58
|
-
else:
|
|
59
|
-
target_n_tips = None
|
|
60
|
-
|
|
61
|
-
while current_time < max_time:
|
|
56
|
+
while current_time < max_time and (n_tips is None or model.n_sampled < n_tips):
|
|
62
57
|
if time.perf_counter() - start_clock > timeout:
|
|
63
58
|
raise TimeoutError("Simulation timed out.")
|
|
64
59
|
|
|
65
60
|
rates = [e.get_propensity(model, current_time) for e in run_events]
|
|
66
|
-
|
|
67
|
-
instantaneous_events = [e for e, r in zip(run_events, rates) if r == np.inf]
|
|
68
|
-
if instantaneous_events:
|
|
69
|
-
event = instantaneous_events[rng.integers(len(instantaneous_events))]
|
|
70
|
-
event.apply(model, run_events, current_time, rng)
|
|
71
|
-
continue
|
|
72
|
-
|
|
73
|
-
if (
|
|
74
|
-
not any(rates)
|
|
75
|
-
or max_tips is not None
|
|
76
|
-
and model.n_sampled >= max_tips
|
|
77
|
-
or target_n_tips is not None
|
|
78
|
-
and model.n_sampled >= target_n_tips
|
|
79
|
-
):
|
|
61
|
+
if not any(rates):
|
|
80
62
|
break
|
|
81
63
|
|
|
82
64
|
time_step = rng.exponential(1 / sum(rates))
|
|
@@ -95,22 +77,23 @@ def simulate_tree(
|
|
|
95
77
|
if event_metadata is not None:
|
|
96
78
|
metadata.update(event_metadata)
|
|
97
79
|
|
|
80
|
+
if current_time != max_time and model.n_sampled != n_tips:
|
|
81
|
+
continue
|
|
82
|
+
|
|
98
83
|
for individual in model.get_population():
|
|
99
84
|
if rng.random() < sampling_probability_at_present:
|
|
100
85
|
model.sample(individual, current_time, True)
|
|
101
86
|
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
return (model.get_sampled_tree(), metadata)
|
|
87
|
+
tree = model.get_sampled_tree()
|
|
88
|
+
if acceptance_criterion is None or acceptance_criterion(tree):
|
|
89
|
+
return (tree, metadata)
|
|
106
90
|
|
|
107
91
|
|
|
108
92
|
def generate_trees(
|
|
109
93
|
output_dir: str | Path,
|
|
110
94
|
n_trees: int,
|
|
111
95
|
events: Sequence[Event],
|
|
112
|
-
|
|
113
|
-
max_tips: int | None = None,
|
|
96
|
+
n_tips: int | None = None,
|
|
114
97
|
max_time: float = np.inf,
|
|
115
98
|
init_state: str | None = None,
|
|
116
99
|
sampling_probability_at_present: float = 0.0,
|
|
@@ -118,6 +101,7 @@ def generate_trees(
|
|
|
118
101
|
seed: int | None = None,
|
|
119
102
|
n_jobs: int = -1,
|
|
120
103
|
timeout: float = np.inf,
|
|
104
|
+
acceptance_criterion: Callable[[Tree], bool] | None = None,
|
|
121
105
|
) -> pd.DataFrame:
|
|
122
106
|
if isinstance(output_dir, str):
|
|
123
107
|
output_dir = Path(output_dir)
|
|
@@ -130,13 +114,13 @@ def generate_trees(
|
|
|
130
114
|
try:
|
|
131
115
|
tree, metadata = simulate_tree(
|
|
132
116
|
events=events,
|
|
133
|
-
|
|
134
|
-
max_tips=max_tips,
|
|
117
|
+
n_tips=n_tips,
|
|
135
118
|
max_time=max_time,
|
|
136
119
|
init_state=init_state,
|
|
137
120
|
sampling_probability_at_present=sampling_probability_at_present,
|
|
138
121
|
seed=seed,
|
|
139
122
|
timeout=timeout,
|
|
123
|
+
acceptance_criterion=acceptance_criterion,
|
|
140
124
|
)
|
|
141
125
|
metadata["file_id"] = i
|
|
142
126
|
if node_features is not None:
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import re
|
|
2
2
|
from pathlib import Path
|
|
3
3
|
|
|
4
|
-
from phylogenie.tree import Tree
|
|
4
|
+
from phylogenie.treesimulator.tree import Tree
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
def parse_newick(newick: str, translations: dict[str, str] | None = None) -> Tree:
|
|
@@ -76,8 +76,8 @@ def load_newick(filepath: str | Path) -> Tree | list[Tree]:
|
|
|
76
76
|
def to_newick(tree: Tree) -> str:
|
|
77
77
|
children_newick = ",".join([to_newick(child) for child in tree.children])
|
|
78
78
|
newick = tree.name
|
|
79
|
-
if tree.
|
|
80
|
-
reprs = {k: repr(v).replace("'", '"') for k, v in tree.
|
|
79
|
+
if tree.metadata:
|
|
80
|
+
reprs = {k: repr(v).replace("'", '"') for k, v in tree.metadata.items()}
|
|
81
81
|
for k, r in reprs.items():
|
|
82
82
|
if "," in k or "=" in k or "]" in k:
|
|
83
83
|
raise ValueError(
|
|
@@ -2,8 +2,8 @@ import re
|
|
|
2
2
|
from collections.abc import Iterator
|
|
3
3
|
from pathlib import Path
|
|
4
4
|
|
|
5
|
-
from phylogenie.io.newick import parse_newick
|
|
6
|
-
from phylogenie.tree import Tree
|
|
5
|
+
from phylogenie.treesimulator.io.newick import parse_newick
|
|
6
|
+
from phylogenie.treesimulator.tree import Tree
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
def _parse_translate_block(lines: Iterator[str]) -> dict[str, str]:
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
from collections import defaultdict
|
|
2
2
|
from dataclasses import dataclass
|
|
3
|
-
from typing import Any
|
|
4
3
|
|
|
5
|
-
from phylogenie.
|
|
4
|
+
from phylogenie.mixins import MetadataMixin
|
|
5
|
+
from phylogenie.treesimulator.tree import Tree
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
@dataclass
|
|
@@ -17,23 +17,22 @@ def _get_node_name(node_id: int, state: str) -> str:
|
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
def get_node_state(node_name: str) -> str:
|
|
20
|
-
|
|
21
|
-
return node_name.split("|")[1]
|
|
22
|
-
except IndexError:
|
|
20
|
+
if "|" not in node_name:
|
|
23
21
|
raise ValueError(
|
|
24
22
|
f"Invalid node name: {node_name} (expected format 'id|state')."
|
|
25
23
|
)
|
|
24
|
+
return node_name.split("|")[-1]
|
|
26
25
|
|
|
27
26
|
|
|
28
|
-
class Model:
|
|
27
|
+
class Model(MetadataMixin):
|
|
29
28
|
def __init__(self, init_state: str):
|
|
29
|
+
super().__init__()
|
|
30
30
|
self._next_node_id = 0
|
|
31
31
|
self._next_individual_id = 0
|
|
32
32
|
self._population: dict[int, Individual] = {}
|
|
33
33
|
self._states: dict[str, set[int]] = defaultdict(set)
|
|
34
34
|
self._sampled: set[str] = set()
|
|
35
35
|
self._tree = self._get_new_individual(init_state).node
|
|
36
|
-
self.context: dict[str, Any] = {}
|
|
37
36
|
|
|
38
37
|
@property
|
|
39
38
|
def n_sampled(self) -> int:
|
|
@@ -110,9 +109,7 @@ class Model:
|
|
|
110
109
|
elif len(node.children) == 1:
|
|
111
110
|
(child,) = node.children
|
|
112
111
|
child.set_parent(node.parent)
|
|
113
|
-
|
|
114
|
-
assert node.branch_length is not None
|
|
115
|
-
child.branch_length += node.branch_length
|
|
112
|
+
child.branch_length += node.branch_length # pyright: ignore
|
|
116
113
|
if node.parent is None:
|
|
117
114
|
return child
|
|
118
115
|
else:
|
|
@@ -2,14 +2,21 @@ from collections import deque
|
|
|
2
2
|
from collections.abc import Callable, Iterator
|
|
3
3
|
from typing import Any
|
|
4
4
|
|
|
5
|
+
from phylogenie.mixins import MetadataMixin
|
|
5
6
|
|
|
6
|
-
|
|
7
|
+
|
|
8
|
+
class Tree(MetadataMixin):
|
|
7
9
|
def __init__(self, name: str = "", branch_length: float | None = None):
|
|
10
|
+
super().__init__()
|
|
8
11
|
self.name = name
|
|
9
12
|
self.branch_length = branch_length
|
|
10
13
|
self._parent: Tree | None = None
|
|
11
14
|
self._children: list[Tree] = []
|
|
12
|
-
|
|
15
|
+
|
|
16
|
+
# ----------------
|
|
17
|
+
# Basic properties
|
|
18
|
+
# ----------------
|
|
19
|
+
# Properties related to parent-child relationships.
|
|
13
20
|
|
|
14
21
|
@property
|
|
15
22
|
def children(self) -> tuple["Tree", ...]:
|
|
@@ -19,59 +26,6 @@ class Tree:
|
|
|
19
26
|
def parent(self) -> "Tree | None":
|
|
20
27
|
return self._parent
|
|
21
28
|
|
|
22
|
-
@property
|
|
23
|
-
def features(self) -> dict[str, Any]:
|
|
24
|
-
return self._features.copy()
|
|
25
|
-
|
|
26
|
-
@property
|
|
27
|
-
def depth(self) -> float:
|
|
28
|
-
if self.parent is None:
|
|
29
|
-
return 0 if self.branch_length is None else self.branch_length
|
|
30
|
-
if self.branch_length is None:
|
|
31
|
-
raise ValueError(f"Branch length of node {self.name} is not set.")
|
|
32
|
-
return self.parent.depth + self.branch_length
|
|
33
|
-
|
|
34
|
-
@property
|
|
35
|
-
def depth_level(self) -> int:
|
|
36
|
-
if self.parent is None:
|
|
37
|
-
return 0
|
|
38
|
-
return self.parent.depth_level + 1
|
|
39
|
-
|
|
40
|
-
@property
|
|
41
|
-
def height(self) -> float:
|
|
42
|
-
if self.is_leaf():
|
|
43
|
-
return 0.0
|
|
44
|
-
if any(child.branch_length is None for child in self.children):
|
|
45
|
-
raise ValueError(
|
|
46
|
-
f"Branch length of one or more children of node {self.name} is not set."
|
|
47
|
-
)
|
|
48
|
-
return max(
|
|
49
|
-
child.branch_length + child.height # pyright: ignore
|
|
50
|
-
for child in self.children
|
|
51
|
-
)
|
|
52
|
-
|
|
53
|
-
@property
|
|
54
|
-
def height_level(self) -> int:
|
|
55
|
-
if self.is_leaf():
|
|
56
|
-
return 0
|
|
57
|
-
return 1 + max(child.height_level for child in self.children)
|
|
58
|
-
|
|
59
|
-
@property
|
|
60
|
-
def n_leaves(self) -> int:
|
|
61
|
-
return len(self.get_leaves())
|
|
62
|
-
|
|
63
|
-
def set(self, key: str, value: Any) -> None:
|
|
64
|
-
self._features[key] = value
|
|
65
|
-
|
|
66
|
-
def update_features(self, features: dict[str, Any]) -> None:
|
|
67
|
-
self._features.update(features)
|
|
68
|
-
|
|
69
|
-
def get(self, key: str) -> Any:
|
|
70
|
-
return self._features[key]
|
|
71
|
-
|
|
72
|
-
def delete(self, key: str) -> None:
|
|
73
|
-
del self._features[key]
|
|
74
|
-
|
|
75
29
|
def add_child(self, child: "Tree") -> "Tree":
|
|
76
30
|
if child.parent is not None:
|
|
77
31
|
raise ValueError(f"Node {child.name} already has a parent.")
|
|
@@ -90,6 +44,53 @@ class Tree:
|
|
|
90
44
|
if parent is not None:
|
|
91
45
|
parent._children.append(self)
|
|
92
46
|
|
|
47
|
+
def is_leaf(self) -> bool:
|
|
48
|
+
return not self.children
|
|
49
|
+
|
|
50
|
+
def get_leaves(self) -> tuple["Tree", ...]:
|
|
51
|
+
return tuple(node for node in self if node.is_leaf())
|
|
52
|
+
|
|
53
|
+
def is_internal(self) -> bool:
|
|
54
|
+
return not self.is_leaf()
|
|
55
|
+
|
|
56
|
+
def get_internal_nodes(self) -> tuple["Tree", ...]:
|
|
57
|
+
return tuple(node for node in self if node.is_internal())
|
|
58
|
+
|
|
59
|
+
def is_binary(self) -> bool:
|
|
60
|
+
return all(len(node.children) in (0, 2) for node in self)
|
|
61
|
+
|
|
62
|
+
# --------------
|
|
63
|
+
# Tree traversal
|
|
64
|
+
# --------------
|
|
65
|
+
# Methods for traversing the tree in various orders.
|
|
66
|
+
|
|
67
|
+
def iter_ancestors(self, stop: "Tree | None" = None) -> Iterator["Tree"]:
|
|
68
|
+
node = self
|
|
69
|
+
while True:
|
|
70
|
+
if node.parent is None:
|
|
71
|
+
if stop is None:
|
|
72
|
+
return
|
|
73
|
+
raise ValueError("Reached root without encountering stop node.")
|
|
74
|
+
node = node.parent
|
|
75
|
+
if node == stop:
|
|
76
|
+
return
|
|
77
|
+
yield node
|
|
78
|
+
|
|
79
|
+
def iter_upward(self, stop: "Tree | None" = None) -> Iterator["Tree"]:
|
|
80
|
+
if self == stop:
|
|
81
|
+
return
|
|
82
|
+
yield self
|
|
83
|
+
yield from self.iter_ancestors(stop=stop)
|
|
84
|
+
|
|
85
|
+
def iter_descendants(self) -> Iterator["Tree"]:
|
|
86
|
+
for child in self.children:
|
|
87
|
+
yield child
|
|
88
|
+
yield from child.iter_descendants()
|
|
89
|
+
|
|
90
|
+
def preorder_traversal(self) -> Iterator["Tree"]:
|
|
91
|
+
yield self
|
|
92
|
+
yield from self.iter_descendants()
|
|
93
|
+
|
|
93
94
|
def inorder_traversal(self) -> Iterator["Tree"]:
|
|
94
95
|
if self.is_leaf():
|
|
95
96
|
yield self
|
|
@@ -101,22 +102,11 @@ class Tree:
|
|
|
101
102
|
yield self
|
|
102
103
|
yield from right.inorder_traversal()
|
|
103
104
|
|
|
104
|
-
def preorder_traversal(self) -> Iterator["Tree"]:
|
|
105
|
-
yield self
|
|
106
|
-
for child in self.children:
|
|
107
|
-
yield from child.preorder_traversal()
|
|
108
|
-
|
|
109
105
|
def postorder_traversal(self) -> Iterator["Tree"]:
|
|
110
106
|
for child in self.children:
|
|
111
107
|
yield from child.postorder_traversal()
|
|
112
108
|
yield self
|
|
113
109
|
|
|
114
|
-
def iter_ancestors(self, stop: "Tree | None" = None) -> Iterator["Tree"]:
|
|
115
|
-
node = self
|
|
116
|
-
while node is not None and node is not stop:
|
|
117
|
-
yield node
|
|
118
|
-
node = node.parent
|
|
119
|
-
|
|
120
110
|
def breadth_first_traversal(self) -> Iterator["Tree"]:
|
|
121
111
|
queue: deque["Tree"] = deque([self])
|
|
122
112
|
while queue:
|
|
@@ -124,39 +114,75 @@ class Tree:
|
|
|
124
114
|
yield node
|
|
125
115
|
queue.extend(node.children)
|
|
126
116
|
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
raise ValueError(f"Node with name {name} not found.")
|
|
117
|
+
# ---------------
|
|
118
|
+
# Tree properties
|
|
119
|
+
# ---------------
|
|
120
|
+
# Properties and methods related to tree metrics like leaf count, depth, height, etc.
|
|
132
121
|
|
|
133
|
-
|
|
134
|
-
|
|
122
|
+
@property
|
|
123
|
+
def n_leaves(self) -> int:
|
|
124
|
+
return len(self.get_leaves())
|
|
135
125
|
|
|
136
|
-
def
|
|
137
|
-
|
|
126
|
+
def branch_length_or_raise(self) -> float:
|
|
127
|
+
if self.parent is None:
|
|
128
|
+
return 0 if self.branch_length is None else self.branch_length
|
|
129
|
+
if self.branch_length is None:
|
|
130
|
+
raise ValueError(f"Branch length of node {self.name} is not set.")
|
|
131
|
+
return self.branch_length
|
|
138
132
|
|
|
139
|
-
|
|
140
|
-
|
|
133
|
+
@property
|
|
134
|
+
def depth_level(self) -> int:
|
|
135
|
+
return 0 if self.parent is None else self.parent.depth_level + 1
|
|
141
136
|
|
|
142
|
-
|
|
143
|
-
|
|
137
|
+
@property
|
|
138
|
+
def depth(self) -> float:
|
|
139
|
+
parent_depth = 0 if self.parent is None else self.parent.depth
|
|
140
|
+
return parent_depth + self.branch_length_or_raise()
|
|
144
141
|
|
|
145
|
-
|
|
146
|
-
|
|
142
|
+
@property
|
|
143
|
+
def height_level(self) -> int:
|
|
144
|
+
if self.is_leaf():
|
|
145
|
+
return 0
|
|
146
|
+
return 1 + max(child.height_level for child in self.children)
|
|
147
147
|
|
|
148
|
-
|
|
148
|
+
@property
|
|
149
|
+
def height(self) -> float:
|
|
150
|
+
if self.is_leaf():
|
|
151
|
+
return 0.0
|
|
152
|
+
return max(
|
|
153
|
+
child.branch_length_or_raise() + child.height for child in self.children
|
|
154
|
+
)
|
|
155
|
+
|
|
156
|
+
# -------------
|
|
157
|
+
# Miscellaneous
|
|
158
|
+
# -------------
|
|
159
|
+
# Other useful miscellaneous methods.
|
|
160
|
+
|
|
161
|
+
def ladderize(self, key: Callable[["Tree"], Any] | None = None) -> None:
|
|
162
|
+
if key is None:
|
|
163
|
+
key = lambda node: node.n_leaves
|
|
149
164
|
self._children.sort(key=key)
|
|
150
165
|
for child in self.children:
|
|
151
166
|
child.ladderize(key)
|
|
152
167
|
|
|
168
|
+
def get_node(self, name: str) -> "Tree":
|
|
169
|
+
for node in self:
|
|
170
|
+
if node.name == name:
|
|
171
|
+
return node
|
|
172
|
+
raise ValueError(f"Node with name {name} not found.")
|
|
173
|
+
|
|
153
174
|
def copy(self):
|
|
154
175
|
new_tree = Tree(self.name, self.branch_length)
|
|
155
|
-
new_tree.
|
|
176
|
+
new_tree.update(self.metadata)
|
|
156
177
|
for child in self.children:
|
|
157
178
|
new_tree.add_child(child.copy())
|
|
158
179
|
return new_tree
|
|
159
180
|
|
|
181
|
+
# ----------------
|
|
182
|
+
# Dunder methods
|
|
183
|
+
# ----------------
|
|
184
|
+
# Special methods for standard behaviors like iteration, length, and representation.
|
|
185
|
+
|
|
160
186
|
def __iter__(self) -> Iterator["Tree"]:
|
|
161
187
|
return self.preorder_traversal()
|
|
162
188
|
|
|
@@ -164,4 +190,4 @@ class Tree:
|
|
|
164
190
|
return sum(1 for _ in self)
|
|
165
191
|
|
|
166
192
|
def __repr__(self) -> str:
|
|
167
|
-
return f"TreeNode(name='{self.name}', branch_length={self.branch_length},
|
|
193
|
+
return f"TreeNode(name='{self.name}', branch_length={self.branch_length}, metadata={self.metadata})"
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from math import comb
|
|
3
|
+
|
|
4
|
+
from phylogenie.treesimulator.tree import Tree
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def get_node_leaf_counts(tree: Tree) -> dict[Tree, int]:
|
|
8
|
+
n_leaves: dict[Tree, int] = {}
|
|
9
|
+
for node in tree.postorder_traversal():
|
|
10
|
+
n_leaves[node] = sum(n_leaves[child] for child in node.children) or 1
|
|
11
|
+
return n_leaves
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def get_node_depth_levels(tree: Tree) -> dict[Tree, int]:
|
|
15
|
+
depth_levels: dict[Tree, int] = {tree: tree.depth_level}
|
|
16
|
+
for node in tree.iter_descendants():
|
|
17
|
+
depth_levels[node] = depth_levels[node.parent] + 1 # pyright: ignore
|
|
18
|
+
return depth_levels
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def get_node_depths(tree: Tree) -> dict[Tree, float]:
|
|
22
|
+
depths: dict[Tree, float] = {tree: tree.depth}
|
|
23
|
+
for node in tree.iter_descendants():
|
|
24
|
+
parent_depth = depths[node.parent] # pyright: ignore
|
|
25
|
+
depths[node] = node.branch_length_or_raise() + parent_depth
|
|
26
|
+
return depths
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def get_node_height_levels(tree: Tree) -> dict[Tree, int]:
|
|
30
|
+
height_levels: dict[Tree, int] = {}
|
|
31
|
+
for node in tree.postorder_traversal():
|
|
32
|
+
height_levels[node] = (
|
|
33
|
+
0
|
|
34
|
+
if node.is_leaf()
|
|
35
|
+
else max(1 + height_levels[child] for child in node.children)
|
|
36
|
+
)
|
|
37
|
+
return height_levels
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def get_node_heights(tree: Tree) -> dict[Tree, float]:
|
|
41
|
+
heights: dict[Tree, float] = {}
|
|
42
|
+
for node in tree.postorder_traversal():
|
|
43
|
+
heights[node] = (
|
|
44
|
+
0
|
|
45
|
+
if node.is_leaf()
|
|
46
|
+
else max(
|
|
47
|
+
child.branch_length_or_raise() + heights[child]
|
|
48
|
+
for child in node.children
|
|
49
|
+
)
|
|
50
|
+
)
|
|
51
|
+
return heights
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def get_mrca(node1: Tree, node2: Tree) -> Tree:
|
|
55
|
+
node1_ancestors = set(node1.iter_upward())
|
|
56
|
+
for node2_ancestor in node2.iter_upward():
|
|
57
|
+
if node2_ancestor in node1_ancestors:
|
|
58
|
+
return node2_ancestor
|
|
59
|
+
raise ValueError(f"No common ancestor found between node {node1} and node {node2}.")
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def get_path(node1: Tree, node2: Tree) -> list[Tree]:
|
|
63
|
+
mrca = get_mrca(node1, node2)
|
|
64
|
+
return [
|
|
65
|
+
*node1.iter_upward(stop=mrca.parent),
|
|
66
|
+
*reversed(list(node2.iter_upward(stop=mrca))),
|
|
67
|
+
]
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def count_hops(node1: Tree, node2: Tree) -> int:
|
|
71
|
+
return len(get_path(node1, node2)) - 1
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def get_distance(node1: Tree, node2: Tree) -> float:
|
|
75
|
+
mrca = get_mrca(node1, node2)
|
|
76
|
+
path = get_path(node1, node2)
|
|
77
|
+
path.remove(mrca)
|
|
78
|
+
return sum(node.branch_length_or_raise() for node in path)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def compute_sackin_index(tree: Tree, normalize: bool = False) -> float:
|
|
82
|
+
depth_levels = get_node_depth_levels(tree)
|
|
83
|
+
sackin_index = sum(dl for node, dl in depth_levels.items() if node.is_leaf())
|
|
84
|
+
if normalize:
|
|
85
|
+
if not tree.is_binary():
|
|
86
|
+
raise ValueError(
|
|
87
|
+
"Normalized Sackin index is only defined for binary trees."
|
|
88
|
+
)
|
|
89
|
+
n = tree.n_leaves
|
|
90
|
+
h = math.floor(math.log2(n))
|
|
91
|
+
min_sackin_index = n * (h + 2) - 2 ** (h + 1)
|
|
92
|
+
max_sackin_index = n * (n - 1) / 2
|
|
93
|
+
return (sackin_index - min_sackin_index) / (max_sackin_index - min_sackin_index)
|
|
94
|
+
return sackin_index
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def compute_mean_leaf_pairwise_distance(tree: Tree) -> float:
|
|
98
|
+
leaves = tree.get_leaves()
|
|
99
|
+
n_leaves = len(leaves)
|
|
100
|
+
if n_leaves < 2:
|
|
101
|
+
return 0.0
|
|
102
|
+
|
|
103
|
+
total_distance = sum(
|
|
104
|
+
get_distance(leaves[i], leaves[j])
|
|
105
|
+
for i in range(n_leaves)
|
|
106
|
+
for j in range(i + 1, n_leaves)
|
|
107
|
+
)
|
|
108
|
+
return total_distance / comb(n_leaves, 2)
|