phylogenie 2.1.4__py3-none-any.whl → 3.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- phylogenie/__init__.py +60 -14
- phylogenie/draw.py +690 -0
- phylogenie/generators/alisim.py +12 -12
- phylogenie/generators/configs.py +26 -4
- phylogenie/generators/dataset.py +3 -3
- phylogenie/generators/factories.py +38 -12
- phylogenie/generators/trees.py +48 -47
- phylogenie/io/__init__.py +3 -0
- phylogenie/io/fasta.py +34 -0
- phylogenie/main.py +27 -10
- phylogenie/mixins.py +33 -0
- phylogenie/skyline/matrix.py +11 -7
- phylogenie/skyline/parameter.py +12 -4
- phylogenie/skyline/vector.py +12 -6
- phylogenie/treesimulator/__init__.py +36 -3
- phylogenie/treesimulator/events/__init__.py +5 -5
- phylogenie/treesimulator/events/base.py +39 -0
- phylogenie/treesimulator/events/contact_tracing.py +38 -23
- phylogenie/treesimulator/events/core.py +21 -12
- phylogenie/treesimulator/events/mutations.py +46 -46
- phylogenie/treesimulator/features.py +49 -0
- phylogenie/treesimulator/gillespie.py +59 -55
- phylogenie/treesimulator/io/__init__.py +4 -0
- phylogenie/treesimulator/io/newick.py +104 -0
- phylogenie/treesimulator/io/nexus.py +50 -0
- phylogenie/treesimulator/model.py +25 -49
- phylogenie/treesimulator/tree.py +196 -0
- phylogenie/treesimulator/utils.py +108 -0
- phylogenie/typings.py +3 -3
- {phylogenie-2.1.4.dist-info → phylogenie-3.1.7.dist-info}/METADATA +13 -15
- phylogenie-3.1.7.dist-info/RECORD +41 -0
- {phylogenie-2.1.4.dist-info → phylogenie-3.1.7.dist-info}/WHEEL +2 -1
- phylogenie-3.1.7.dist-info/entry_points.txt +2 -0
- phylogenie-3.1.7.dist-info/top_level.txt +1 -0
- phylogenie/io.py +0 -107
- phylogenie/tree.py +0 -92
- phylogenie/utils.py +0 -17
- phylogenie-2.1.4.dist-info/RECORD +0 -32
- phylogenie-2.1.4.dist-info/entry_points.txt +0 -3
- {phylogenie-2.1.4.dist-info → phylogenie-3.1.7.dist-info/licenses}/LICENSE.txt +0 -0
|
@@ -1,35 +1,36 @@
|
|
|
1
|
-
import os
|
|
2
1
|
import time
|
|
3
|
-
from collections.abc import Sequence
|
|
2
|
+
from collections.abc import Iterable, Sequence
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any, Callable
|
|
4
5
|
|
|
5
6
|
import joblib
|
|
6
7
|
import numpy as np
|
|
8
|
+
import pandas as pd
|
|
7
9
|
from numpy.random import default_rng
|
|
8
10
|
from tqdm import tqdm
|
|
9
11
|
|
|
10
|
-
from phylogenie.
|
|
11
|
-
from phylogenie.
|
|
12
|
-
from phylogenie.treesimulator.
|
|
13
|
-
|
|
14
|
-
|
|
12
|
+
from phylogenie.treesimulator.events import Event
|
|
13
|
+
from phylogenie.treesimulator.features import Feature, set_features
|
|
14
|
+
from phylogenie.treesimulator.io import dump_newick
|
|
15
|
+
from phylogenie.treesimulator.model import Model
|
|
16
|
+
from phylogenie.treesimulator.tree import Tree
|
|
15
17
|
|
|
16
18
|
|
|
17
19
|
def simulate_tree(
|
|
18
20
|
events: Sequence[Event],
|
|
19
|
-
|
|
20
|
-
max_tips: int = MAX_TIPS,
|
|
21
|
+
n_tips: int | None = None,
|
|
21
22
|
max_time: float = np.inf,
|
|
22
23
|
init_state: str | None = None,
|
|
23
24
|
sampling_probability_at_present: float = 0.0,
|
|
24
25
|
seed: int | None = None,
|
|
25
26
|
timeout: float = np.inf,
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
if max_time == np.inf
|
|
27
|
+
acceptance_criterion: Callable[[Tree], bool] | None = None,
|
|
28
|
+
) -> tuple[Tree, dict[str, Any]]:
|
|
29
|
+
if (max_time != np.inf) == (n_tips is not None):
|
|
30
|
+
raise ValueError("Exactly one of max_time or n_tips must be specified.")
|
|
31
|
+
if sampling_probability_at_present and max_time == np.inf:
|
|
31
32
|
raise ValueError(
|
|
32
|
-
"sampling_probability_at_present
|
|
33
|
+
"sampling_probability_at_present can only be used with max_time."
|
|
33
34
|
)
|
|
34
35
|
|
|
35
36
|
states = {e.state for e in events if e.state}
|
|
@@ -45,31 +46,19 @@ def simulate_tree(
|
|
|
45
46
|
rng = default_rng(seed)
|
|
46
47
|
start_clock = time.perf_counter()
|
|
47
48
|
while True:
|
|
48
|
-
model = Model(init_state
|
|
49
|
+
model = Model(init_state)
|
|
50
|
+
metadata: dict[str, Any] = {}
|
|
51
|
+
run_events = list(events)
|
|
49
52
|
current_time = 0.0
|
|
50
53
|
change_times = sorted(set(t for e in events for t in e.rate.change_times))
|
|
51
54
|
next_change_time = change_times.pop(0) if change_times else np.inf
|
|
52
|
-
target_n_tips = rng.integers(min_tips, max_tips) if max_time == np.inf else None
|
|
53
55
|
|
|
54
|
-
while current_time < max_time:
|
|
56
|
+
while current_time < max_time and (n_tips is None or model.n_sampled < n_tips):
|
|
55
57
|
if time.perf_counter() - start_clock > timeout:
|
|
56
58
|
raise TimeoutError("Simulation timed out.")
|
|
57
59
|
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
instantaneous_events = [e for e, r in zip(events, rates) if r == np.inf]
|
|
62
|
-
if instantaneous_events:
|
|
63
|
-
event = instantaneous_events[rng.integers(len(instantaneous_events))]
|
|
64
|
-
event.apply(model, current_time, rng)
|
|
65
|
-
continue
|
|
66
|
-
|
|
67
|
-
if (
|
|
68
|
-
not any(rates)
|
|
69
|
-
or model.n_sampled > max_tips
|
|
70
|
-
or target_n_tips is not None
|
|
71
|
-
and model.n_sampled >= target_n_tips
|
|
72
|
-
):
|
|
60
|
+
rates = [e.get_propensity(model, current_time) for e in run_events]
|
|
61
|
+
if not any(rates):
|
|
73
62
|
break
|
|
74
63
|
|
|
75
64
|
time_step = rng.exponential(1 / sum(rates))
|
|
@@ -83,56 +72,71 @@ def simulate_tree(
|
|
|
83
72
|
current_time += time_step
|
|
84
73
|
|
|
85
74
|
event_idx = np.searchsorted(np.cumsum(rates) / sum(rates), rng.random())
|
|
86
|
-
|
|
75
|
+
event = run_events[int(event_idx)]
|
|
76
|
+
event_metadata = event.apply(model, run_events, current_time, rng)
|
|
77
|
+
if event_metadata is not None:
|
|
78
|
+
metadata.update(event_metadata)
|
|
79
|
+
|
|
80
|
+
if current_time != max_time and model.n_sampled != n_tips:
|
|
81
|
+
continue
|
|
87
82
|
|
|
88
83
|
for individual in model.get_population():
|
|
89
84
|
if rng.random() < sampling_probability_at_present:
|
|
90
85
|
model.sample(individual, current_time, True)
|
|
91
86
|
|
|
92
|
-
|
|
93
|
-
|
|
87
|
+
tree = model.get_sampled_tree()
|
|
88
|
+
if acceptance_criterion is None or acceptance_criterion(tree):
|
|
89
|
+
return (tree, metadata)
|
|
94
90
|
|
|
95
91
|
|
|
96
92
|
def generate_trees(
|
|
97
|
-
output_dir: str,
|
|
93
|
+
output_dir: str | Path,
|
|
98
94
|
n_trees: int,
|
|
99
|
-
events:
|
|
100
|
-
|
|
101
|
-
max_tips: int = 2**32,
|
|
95
|
+
events: Sequence[Event],
|
|
96
|
+
n_tips: int | None = None,
|
|
102
97
|
max_time: float = np.inf,
|
|
103
98
|
init_state: str | None = None,
|
|
104
99
|
sampling_probability_at_present: float = 0.0,
|
|
100
|
+
node_features: Iterable[Feature] | None = None,
|
|
105
101
|
seed: int | None = None,
|
|
106
102
|
n_jobs: int = -1,
|
|
107
103
|
timeout: float = np.inf,
|
|
108
|
-
|
|
109
|
-
|
|
104
|
+
acceptance_criterion: Callable[[Tree], bool] | None = None,
|
|
105
|
+
) -> pd.DataFrame:
|
|
106
|
+
if isinstance(output_dir, str):
|
|
107
|
+
output_dir = Path(output_dir)
|
|
108
|
+
if output_dir.exists():
|
|
109
|
+
raise FileExistsError(f"Output directory {output_dir} already exists")
|
|
110
|
+
output_dir.mkdir(parents=True)
|
|
111
|
+
|
|
112
|
+
def _simulate_tree(i: int, seed: int) -> dict[str, Any]:
|
|
110
113
|
while True:
|
|
111
114
|
try:
|
|
112
|
-
|
|
115
|
+
tree, metadata = simulate_tree(
|
|
113
116
|
events=events,
|
|
114
|
-
|
|
115
|
-
max_tips=max_tips,
|
|
117
|
+
n_tips=n_tips,
|
|
116
118
|
max_time=max_time,
|
|
117
119
|
init_state=init_state,
|
|
118
120
|
sampling_probability_at_present=sampling_probability_at_present,
|
|
119
121
|
seed=seed,
|
|
120
122
|
timeout=timeout,
|
|
123
|
+
acceptance_criterion=acceptance_criterion,
|
|
121
124
|
)
|
|
125
|
+
metadata["file_id"] = i
|
|
126
|
+
if node_features is not None:
|
|
127
|
+
set_features(tree, node_features)
|
|
128
|
+
dump_newick(tree, output_dir / f"{i}.nwk")
|
|
129
|
+
return metadata
|
|
122
130
|
except TimeoutError:
|
|
123
131
|
print("Simulation timed out, retrying with a different seed...")
|
|
124
132
|
seed += 1
|
|
125
133
|
|
|
126
|
-
if os.path.exists(output_dir):
|
|
127
|
-
raise FileExistsError(f"Output directory {output_dir} already exists")
|
|
128
|
-
os.mkdir(output_dir)
|
|
129
|
-
|
|
130
134
|
rng = default_rng(seed)
|
|
131
135
|
jobs = joblib.Parallel(n_jobs=n_jobs, return_as="generator_unordered")(
|
|
132
|
-
joblib.delayed(_simulate_tree)(seed=int(rng.integers(2**32)))
|
|
133
|
-
for
|
|
136
|
+
joblib.delayed(_simulate_tree)(i=i, seed=int(rng.integers(2**32)))
|
|
137
|
+
for i in range(n_trees)
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
return pd.DataFrame(
|
|
141
|
+
[md for md in tqdm(jobs, f"Generating trees in {output_dir}...", n_trees)]
|
|
134
142
|
)
|
|
135
|
-
for i, tree in tqdm(
|
|
136
|
-
enumerate(jobs), total=n_trees, desc=f"Generating trees in {output_dir}..."
|
|
137
|
-
):
|
|
138
|
-
dump_newick(tree, os.path.join(output_dir, f"{i}.nwk"))
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
from phylogenie.treesimulator.tree import Tree
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def parse_newick(newick: str, translations: dict[str, str] | None = None) -> Tree:
    """Parse a single Newick string into a tree and return its root node.

    Supported syntax: nested ``(...)`` groups, node names, ``[&key=value,...]``
    feature blocks placed right after a node name, ``:length`` branch lengths
    and the terminating ``;``. A leading tree-level ``[&...]`` comment is
    dropped before parsing.

    Args:
        newick: The Newick text of one tree, including the final ``;``.
        translations: Optional mapping from names as written in the string to
            the names the nodes should receive (e.g. a NEXUS TRANSLATE table).
            Names missing from the mapping are kept unchanged.

    Returns:
        The root node of the parsed tree.

    Raises:
        ValueError: If the string is empty, ends before the terminating ``;``,
            contains an unbalanced ``)``, has a malformed feature block, or
            has a branch length that is not a valid float.
    """
    # Function-scope import keeps this block self-contained (stdlib only).
    import ast

    newick = newick.strip()
    # Drop a leading tree-level "[&...]" comment before parsing.
    newick = re.sub(r"^\[\&[^\]]*\]", "", newick).strip()
    end = len(newick)
    i = 0

    def _peek() -> str:
        # Current character, or "" at end of input, so that truncated strings
        # reach an explicit ValueError instead of an IndexError.
        return newick[i] if i < end else ""

    def _read_chars(stoppers: list[str]) -> str:
        # Consume characters up to (not including) the next stopper.
        nonlocal i
        start = i
        while i < end and newick[i] not in stoppers:
            i += 1
        if i == end:
            raise ValueError(f"Expected one of {stoppers}, got end of string")
        return newick[start:i]

    def _parse_value(text: str) -> object:
        # SECURITY: this used to be eval(text). Feature text comes from files,
        # so eval could execute arbitrary code and also resolved bare words to
        # this function's local variables. literal_eval accepts Python literals
        # only; anything else is kept as the raw string.
        try:
            return ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return text

    # `stack` holds the sibling lists of the enclosing, still-open groups.
    stack: list[list[Tree]] = []
    # Nodes of the group that was just closed; they become the children of the
    # next node that is read (the group's label).
    current_children: list[Tree] = []
    # Siblings collected so far at the current nesting level.
    current_nodes: list[Tree] = []

    while i < end:
        if newick[i] == "(":
            stack.append(current_nodes)
            current_nodes = []
            i += 1
            continue

        name = _read_chars([":", "[", ",", ")", ";"])
        if translations is not None and name in translations:
            name = translations[name]
        current_node = Tree(name)

        if _peek() == "[":
            i += 1
            if _peek() != "&":
                raise ValueError("Expected '[&' at the start of node features")
            i += 1
            # Split only on commas that start a new "key=" pair, so commas
            # inside a value (e.g. a list) stay part of that value.
            features = re.split(r",(?=[^,]+=)", _read_chars(["]"]))
            i += 1  # skip the closing "]"
            for feature in features:
                # partition keeps any further "=" inside the value.
                key, separator, value = feature.partition("=")
                if not separator:
                    raise ValueError(
                        f"Invalid node feature {feature!r}: expected 'key=value'"
                    )
                current_node.set(key, _parse_value(value))

        if _peek() == ":":
            i += 1
            current_node.branch_length = float(_read_chars([",", ")", ";"]))

        for node in current_children:
            current_node.add_child(node)
        current_children = []
        current_nodes.append(current_node)

        terminator = _peek()
        if terminator == ")":
            if not stack:
                raise ValueError("Unbalanced ')' in Newick string.")
            current_children = current_nodes
            current_nodes = stack.pop()
        elif terminator == ";":
            return current_node

        i += 1

    raise ValueError("Unexpected end of Newick string: missing terminating ';'.")
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def load_newick(filepath: str | Path) -> Tree | list[Tree]:
    """Read a file holding one Newick tree per line.

    Blank lines (for example a trailing empty line at the end of the file)
    are skipped instead of being handed to the parser.

    Args:
        filepath: Path of the Newick file to read.

    Returns:
        The single tree when the file contains exactly one, otherwise the list
        of all parsed trees (empty when the file holds no tree).
    """
    with open(filepath, "r") as file:
        trees = [parse_newick(line) for line in file if line.strip()]
    return trees[0] if len(trees) == 1 else trees
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def to_newick(tree: Tree) -> str:
|
|
77
|
+
children_newick = ",".join([to_newick(child) for child in tree.children])
|
|
78
|
+
newick = tree.name
|
|
79
|
+
if tree.metadata:
|
|
80
|
+
reprs = {k: repr(v).replace("'", '"') for k, v in tree.metadata.items()}
|
|
81
|
+
for k, r in reprs.items():
|
|
82
|
+
if "," in k or "=" in k or "]" in k:
|
|
83
|
+
raise ValueError(
|
|
84
|
+
f"Invalid feature key `{k}`: keys must not contain ',', '=', or ']'"
|
|
85
|
+
)
|
|
86
|
+
if "=" in r or "]" in r:
|
|
87
|
+
raise ValueError(
|
|
88
|
+
f"Invalid value `{r}` for feature `{k}`: values must not contain '=' or ']'"
|
|
89
|
+
)
|
|
90
|
+
features = [f"{k}={repr}" for k, repr in reprs.items()]
|
|
91
|
+
newick += f"[&{','.join(features)}]"
|
|
92
|
+
if children_newick:
|
|
93
|
+
newick = f"({children_newick}){newick}"
|
|
94
|
+
if tree.branch_length is not None:
|
|
95
|
+
newick += f":{tree.branch_length}"
|
|
96
|
+
return newick
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def dump_newick(trees: Tree | list[Tree], filepath: str | Path) -> None:
    """Write one or more trees to ``filepath``, one Newick string per line.

    Each line is the tree's Newick text followed by ``;``. An existing file is
    overwritten.

    Args:
        trees: A single tree or a list of trees.
        filepath: Destination path.
    """
    forest = [trees] if isinstance(trees, Tree) else trees
    with open(filepath, "w") as handle:
        handle.writelines(to_newick(item) + ";\n" for item in forest)
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from collections.abc import Iterator
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
from phylogenie.treesimulator.io.newick import parse_newick
|
|
6
|
+
from phylogenie.treesimulator.tree import Tree
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _parse_translate_block(lines: Iterator[str]) -> dict[str, str]:
|
|
10
|
+
translations: dict[str, str] = {}
|
|
11
|
+
for line in lines:
|
|
12
|
+
line = line.strip()
|
|
13
|
+
match = re.match(r"(\d+)\s+['\"]?([^'\",;]+)['\"]?", line)
|
|
14
|
+
if match is None:
|
|
15
|
+
if ";" in line:
|
|
16
|
+
return translations
|
|
17
|
+
else:
|
|
18
|
+
raise ValueError("Invalid translate line. Expected '<num> <name>'.")
|
|
19
|
+
translations[match.group(1)] = match.group(2)
|
|
20
|
+
raise ValueError("Translate block not terminated with ';'.")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _parse_trees_block(lines: Iterator[str]) -> dict[str, Tree]:
    """Parse the body of a NEXUS TREES block.

    Handles an optional ``TRANSLATE`` command (keyword on its own line) and
    any number of ``TREE <name> = <newick>`` statements, each on one line.
    Translations apply to the trees that follow them. Blank lines are ignored.

    Args:
        lines: Iterator positioned on the line after ``BEGIN TREES;``.

    Returns:
        Mapping from tree name to parsed tree, in file order.

    Raises:
        ValueError: On a line that is not a recognised statement, on a
            duplicate tree name, or if the block is never closed by ``END;``
            (or its synonym ``ENDBLOCK;``).
    """
    trees: dict[str, Tree] = {}
    translations: dict[str, str] = {}
    for line in lines:
        line = line.strip()
        if not line:
            continue
        keyword = line.upper()
        if keyword == "TRANSLATE":
            # The helper consumes lines from the same iterator up to the
            # terminating ';'.
            translations = _parse_translate_block(lines)
        elif keyword in ("END;", "ENDBLOCK;"):
            return trees
        else:
            match = re.match(r"^TREE\s*\*?\s+(\S+)\s*=\s*(.+)$", line, re.IGNORECASE)
            if match is None:
                raise ValueError(
                    "Invalid tree line. Expected 'TREE <name> = <newick>'."
                )
            name = match.group(1)
            if name in trees:
                raise ValueError(f"Duplicate tree name found: {name}.")
            trees[name] = parse_newick(match.group(2), translations)
    raise ValueError("Unterminated TREES block.")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def load_nexus(nexus_file: str | Path) -> dict[str, Tree]:
    """Load the trees of the first TREES block found in a NEXUS file.

    Lines are scanned until one equals ``BEGIN TREES;`` (ignoring case and
    surrounding whitespace); the rest of the block is then parsed.

    Args:
        nexus_file: Path of the NEXUS file.

    Returns:
        Mapping from tree name to parsed tree.

    Raises:
        ValueError: If the file contains no TREES block.
    """
    with open(nexus_file, "r") as handle:
        # any() stops at the first matching line, leaving the handle
        # positioned on the first line of the block body.
        block_found = any(
            raw.strip().upper() == "BEGIN TREES;" for raw in handle
        )
        if block_found:
            return _parse_trees_block(handle)
    raise ValueError("No TREES block found in the NEXUS file.")
|
|
@@ -1,33 +1,8 @@
|
|
|
1
|
-
from abc import ABC, abstractmethod
|
|
2
1
|
from collections import defaultdict
|
|
3
|
-
from collections.abc import Sequence
|
|
4
2
|
from dataclasses import dataclass
|
|
5
|
-
from typing import Any
|
|
6
3
|
|
|
7
|
-
|
|
8
|
-
from
|
|
9
|
-
|
|
10
|
-
from phylogenie.skyline import SkylineParameterLike, skyline_parameter
|
|
11
|
-
from phylogenie.tree import Tree
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
class Event(ABC):
|
|
15
|
-
def __init__(self, state: str, rate: SkylineParameterLike):
|
|
16
|
-
self.state = state
|
|
17
|
-
self.rate = skyline_parameter(rate)
|
|
18
|
-
|
|
19
|
-
def draw_individual(self, model: "Model", rng: Generator) -> int:
|
|
20
|
-
return rng.choice(model.get_population(self.state))
|
|
21
|
-
|
|
22
|
-
def get_propensity(self, model: "Model", time: float) -> float:
|
|
23
|
-
n_individuals = model.count_individuals(self.state)
|
|
24
|
-
rate = self.rate.get_value_at_time(time)
|
|
25
|
-
if rate == np.inf and not n_individuals:
|
|
26
|
-
return 0
|
|
27
|
-
return rate * n_individuals
|
|
28
|
-
|
|
29
|
-
@abstractmethod
|
|
30
|
-
def apply(self, model: "Model", time: float, rng: Generator) -> None: ...
|
|
4
|
+
from phylogenie.mixins import MetadataMixin
|
|
5
|
+
from phylogenie.treesimulator.tree import Tree
|
|
31
6
|
|
|
32
7
|
|
|
33
8
|
@dataclass
|
|
@@ -37,31 +12,36 @@ class Individual:
|
|
|
37
12
|
state: str
|
|
38
13
|
|
|
39
14
|
|
|
40
|
-
|
|
41
|
-
|
|
15
|
+
def _get_node_name(node_id: int, state: str) -> str:
|
|
16
|
+
return f"{node_id}|{state}"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_node_state(node_name: str) -> str:
|
|
20
|
+
if "|" not in node_name:
|
|
21
|
+
raise ValueError(
|
|
22
|
+
f"Invalid node name: {node_name} (expected format 'id|state')."
|
|
23
|
+
)
|
|
24
|
+
return node_name.split("|")[-1]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class Model(MetadataMixin):
|
|
28
|
+
def __init__(self, init_state: str):
|
|
29
|
+
super().__init__()
|
|
42
30
|
self._next_node_id = 0
|
|
43
31
|
self._next_individual_id = 0
|
|
44
32
|
self._population: dict[int, Individual] = {}
|
|
45
33
|
self._states: dict[str, set[int]] = defaultdict(set)
|
|
46
34
|
self._sampled: set[str] = set()
|
|
47
35
|
self._tree = self._get_new_individual(init_state).node
|
|
48
|
-
self._events = list(events)
|
|
49
|
-
self.context: dict[str, Any] = {}
|
|
50
36
|
|
|
51
37
|
@property
|
|
52
38
|
def n_sampled(self) -> int:
|
|
53
39
|
return len(self._sampled)
|
|
54
40
|
|
|
55
|
-
@property
|
|
56
|
-
def events(self) -> tuple[Event, ...]:
|
|
57
|
-
return tuple(self._events)
|
|
58
|
-
|
|
59
|
-
def add_event(self, event: Event) -> None:
|
|
60
|
-
self._events.append(event)
|
|
61
|
-
|
|
62
41
|
def _get_new_node(self, state: str) -> Tree:
|
|
63
42
|
self._next_node_id += 1
|
|
64
|
-
|
|
43
|
+
node = Tree(_get_node_name(self._next_node_id, state))
|
|
44
|
+
return node
|
|
65
45
|
|
|
66
46
|
def _get_new_individual(self, state: str) -> Individual:
|
|
67
47
|
self._next_individual_id += 1
|
|
@@ -74,10 +54,8 @@ class Model:
|
|
|
74
54
|
|
|
75
55
|
def _set_branch_length(self, node: Tree, time: float) -> None:
|
|
76
56
|
if node.branch_length is not None:
|
|
77
|
-
raise ValueError(f"Branch length of node {node.
|
|
78
|
-
node.branch_length =
|
|
79
|
-
time if node.parent is None else time - node.parent.get_time()
|
|
80
|
-
)
|
|
57
|
+
raise ValueError(f"Branch length of node {node.name} is already set.")
|
|
58
|
+
node.branch_length = time if node.parent is None else time - node.parent.depth
|
|
81
59
|
|
|
82
60
|
def _stem(self, individual: Individual, time: float) -> None:
|
|
83
61
|
self._set_branch_length(individual.node, time)
|
|
@@ -108,12 +86,12 @@ class Model:
|
|
|
108
86
|
def sample(self, id: int, time: float, removal: bool) -> None:
|
|
109
87
|
individual = self._population[id]
|
|
110
88
|
if removal:
|
|
111
|
-
self._sampled.add(individual.node.
|
|
89
|
+
self._sampled.add(individual.node.name)
|
|
112
90
|
self.remove(id, time)
|
|
113
91
|
else:
|
|
114
92
|
sample_node = self._get_new_node(individual.state)
|
|
115
93
|
sample_node.branch_length = 0.0
|
|
116
|
-
self._sampled.add(sample_node.
|
|
94
|
+
self._sampled.add(sample_node.name)
|
|
117
95
|
individual.node.add_child(sample_node)
|
|
118
96
|
self._stem(individual, time)
|
|
119
97
|
|
|
@@ -123,7 +101,7 @@ class Model:
|
|
|
123
101
|
def get_sampled_tree(self) -> Tree:
|
|
124
102
|
tree = self._tree.copy()
|
|
125
103
|
for node in list(tree.postorder_traversal()):
|
|
126
|
-
if node.
|
|
104
|
+
if node.name not in self._sampled and not node.children:
|
|
127
105
|
if node.parent is None:
|
|
128
106
|
raise ValueError("No samples in the tree.")
|
|
129
107
|
else:
|
|
@@ -131,9 +109,7 @@ class Model:
|
|
|
131
109
|
elif len(node.children) == 1:
|
|
132
110
|
(child,) = node.children
|
|
133
111
|
child.set_parent(node.parent)
|
|
134
|
-
|
|
135
|
-
assert node.branch_length is not None
|
|
136
|
-
child.branch_length += node.branch_length
|
|
112
|
+
child.branch_length += node.branch_length # pyright: ignore
|
|
137
113
|
if node.parent is None:
|
|
138
114
|
return child
|
|
139
115
|
else:
|
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
from collections import deque
|
|
2
|
+
from collections.abc import Callable, Iterator
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from phylogenie.mixins import MetadataMixin
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Tree(MetadataMixin):
    """A node of a rooted tree; a whole tree is represented by its root node.

    Every node has a name, an optional branch length (length of the edge to
    its parent), a reference to its parent and an ordered list of children.
    Metadata support (``metadata``, ``update``) is inherited from
    ``MetadataMixin``, which is defined outside this module.
    """

    def __init__(self, name: str = "", branch_length: float | None = None):
        """Create a detached node: no parent and no children."""
        super().__init__()
        self.name = name
        self.branch_length = branch_length
        self._parent: Tree | None = None
        self._children: list[Tree] = []

    # ----------------
    # Basic properties
    # ----------------
    # Properties related to parent-child relationships.

    @property
    def children(self) -> tuple["Tree", ...]:
        """Children of this node, as a tuple snapshot of the internal list."""
        return tuple(self._children)

    @property
    def parent(self) -> "Tree | None":
        """Parent of this node, or ``None`` for a root."""
        return self._parent

    def add_child(self, child: "Tree") -> "Tree":
        """Append ``child`` to this node's children and return this node.

        Raises:
            ValueError: If ``child`` already has a parent.
        """
        if child.parent is not None:
            raise ValueError(f"Node {child.name} already has a parent.")
        child._parent = self
        self._children.append(child)
        return self

    def remove_child(self, child: "Tree") -> None:
        """Detach ``child`` from this node.

        ``list.remove`` raises ``ValueError`` if ``child`` is not one of this
        node's children.
        """
        self._children.remove(child)
        child._parent = None

    def set_parent(self, parent: "Tree | None") -> None:
        """Move this node under ``parent`` (``None`` makes it a root).

        The node is first detached from its current parent, if any, and then
        appended as the last child of the new one.
        """
        if self.parent is not None:
            self.parent.remove_child(self)
        self._parent = parent
        if parent is not None:
            parent._children.append(self)

    def is_leaf(self) -> bool:
        """Return ``True`` if this node has no children."""
        return not self.children

    def get_leaves(self) -> tuple["Tree", ...]:
        """Return the leaves of this subtree, in pre-order."""
        return tuple(node for node in self if node.is_leaf())

    def is_internal(self) -> bool:
        """Return ``True`` if this node has at least one child."""
        return not self.is_leaf()

    def get_internal_nodes(self) -> tuple["Tree", ...]:
        """Return the internal nodes of this subtree, in pre-order."""
        return tuple(node for node in self if node.is_internal())

    def is_binary(self) -> bool:
        """Return ``True`` if every node of this subtree has 0 or 2 children."""
        return all(len(node.children) in (0, 2) for node in self)

    # --------------
    # Tree traversal
    # --------------
    # Methods for traversing the tree in various orders. All of them are lazy
    # generators.

    def iter_ancestors(self, stop: "Tree | None" = None) -> Iterator["Tree"]:
        """Yield the parent, grandparent, ... of this node, moving rootwards.

        Args:
            stop: Optional ancestor at which to stop; it is not yielded.

        Raises:
            ValueError: If ``stop`` is given but the root is reached without
                meeting it. Being a generator, this is raised while
                iterating, not when the method is called.
        """
        node = self
        while True:
            if node.parent is None:
                if stop is None:
                    return
                raise ValueError("Reached root without encountering stop node.")
            node = node.parent
            if node == stop:
                return
            yield node

    def iter_upward(self, stop: "Tree | None" = None) -> Iterator["Tree"]:
        """Yield this node followed by its ancestors (see ``iter_ancestors``).

        Yields nothing if this node is itself ``stop``.
        """
        if self == stop:
            return
        yield self
        yield from self.iter_ancestors(stop=stop)

    def iter_descendants(self) -> Iterator["Tree"]:
        """Yield every node below this one, depth-first in pre-order."""
        for child in self.children:
            yield child
            yield from child.iter_descendants()

    def preorder_traversal(self) -> Iterator["Tree"]:
        """Yield this node, then all its descendants in pre-order."""
        yield self
        yield from self.iter_descendants()

    def inorder_traversal(self) -> Iterator["Tree"]:
        """Yield nodes in in-order: left subtree, node, right subtree.

        Raises:
            ValueError: If an internal node does not have exactly two
                children (raised while iterating).
        """
        if self.is_leaf():
            yield self
            return
        if len(self.children) != 2:
            raise ValueError("Inorder traversal is only defined for binary trees.")
        left, right = self.children
        yield from left.inorder_traversal()
        yield self
        yield from right.inorder_traversal()

    def postorder_traversal(self) -> Iterator["Tree"]:
        """Yield all descendants before the node itself (children first)."""
        for child in self.children:
            yield from child.postorder_traversal()
        yield self

    def breadth_first_traversal(self) -> Iterator["Tree"]:
        """Yield nodes level by level, starting from this node."""
        queue: deque["Tree"] = deque([self])
        while queue:
            node = queue.popleft()
            yield node
            queue.extend(node.children)

    # ---------------
    # Tree properties
    # ---------------
    # Properties and methods related to tree metrics like leaf count, depth, height, etc.

    @property
    def n_leaves(self) -> int:
        """Number of leaves in this subtree (computed by a full traversal)."""
        return len(self.get_leaves())

    def branch_length_or_raise(self) -> float:
        """Return the branch length, treating an unset root length as 0.

        Raises:
            ValueError: If this node has a parent but no branch length.
        """
        if self.parent is None:
            return 0 if self.branch_length is None else self.branch_length
        if self.branch_length is None:
            raise ValueError(f"Branch length of node {self.name} is not set.")
        return self.branch_length

    @property
    def depth_level(self) -> int:
        """Number of edges between this node and the root (0 for the root)."""
        return 0 if self.parent is None else self.parent.depth_level + 1

    @property
    def depth(self) -> float:
        """Sum of branch lengths from the root down to this node.

        The root's own branch length is included when it is set.
        """
        parent_depth = 0 if self.parent is None else self.parent.depth
        return parent_depth + self.branch_length_or_raise()

    @property
    def height_level(self) -> int:
        """Number of edges on the longest path down to a leaf (0 for a leaf)."""
        if self.is_leaf():
            return 0
        return 1 + max(child.height_level for child in self.children)

    @property
    def height(self) -> float:
        """Largest summed branch length from this node down to any leaf."""
        if self.is_leaf():
            return 0.0
        return max(
            child.branch_length_or_raise() + child.height for child in self.children
        )

    # -------------
    # Miscellaneous
    # -------------
    # Other useful miscellaneous methods.

    def ladderize(self, key: Callable[["Tree"], Any] | None = None) -> None:
        """Sort the children of every node in this subtree, in place.

        Args:
            key: Sort key applied to each child; children are ordered by
                ascending key. Defaults to the child's number of leaves.
        """
        def _default_key(node: Tree) -> int:
            return node.n_leaves

        if key is None:
            key = _default_key
        self._children.sort(key=key)
        for child in self.children:
            child.ladderize(key)

    def get_node(self, name: str) -> "Tree":
        """Return the first node (in pre-order) whose name equals ``name``.

        Raises:
            ValueError: If no node of this subtree has that name.
        """
        for node in self:
            if node.name == name:
                return node
        raise ValueError(f"Node {name} not found.")

    def copy(self) -> "Tree":
        """Return a recursive copy of this subtree.

        New nodes are created with the same names and branch lengths, and
        ``self.metadata`` is passed to the new node's ``update``. The returned
        root has no parent, even if this node has one.
        """
        new_tree = Tree(self.name, self.branch_length)
        # NOTE(review): presumably copies metadata entries onto the new node;
        # `update` is defined in MetadataMixin, outside this file.
        new_tree.update(self.metadata)
        for child in self.children:
            new_tree.add_child(child.copy())
        return new_tree

    # ----------------
    # Dunder methods
    # ----------------
    # Special methods for standard behaviors like iteration, length, and representation.

    def __iter__(self) -> Iterator["Tree"]:
        """Iterate over the nodes of this subtree in pre-order."""
        return self.preorder_traversal()

    def __len__(self) -> int:
        """Number of nodes in this subtree, including this node."""
        return sum(1 for _ in self)

    def __repr__(self) -> str:
        """Return a debug representation of this node.

        NOTE(review): the text says ``TreeNode`` although the class is named
        ``Tree``.
        """
        return f"TreeNode(name='{self.name}', branch_length={self.branch_length}, metadata={self.metadata})"
|