phylogenie 2.1.30__py3-none-any.whl → 3.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- phylogenie/__init__.py +35 -28
- phylogenie/draw.py +22 -29
- phylogenie/generators/alisim.py +7 -8
- phylogenie/generators/configs.py +23 -1
- phylogenie/generators/dataset.py +1 -1
- phylogenie/generators/factories.py +10 -13
- phylogenie/generators/trees.py +19 -10
- phylogenie/io/__init__.py +1 -3
- phylogenie/mixins.py +41 -0
- phylogenie/treesimulator/__init__.py +29 -3
- phylogenie/treesimulator/events/base.py +0 -3
- phylogenie/treesimulator/events/contact_tracing.py +8 -9
- phylogenie/treesimulator/events/mutations.py +7 -8
- phylogenie/treesimulator/features.py +3 -3
- phylogenie/treesimulator/gillespie.py +21 -37
- phylogenie/treesimulator/io/__init__.py +4 -0
- phylogenie/{io → treesimulator/io}/newick.py +3 -3
- phylogenie/{io → treesimulator/io}/nexus.py +2 -2
- phylogenie/treesimulator/model.py +7 -10
- phylogenie/{tree.py → treesimulator/tree.py} +110 -84
- phylogenie/treesimulator/utils.py +108 -0
- {phylogenie-2.1.30.dist-info → phylogenie-3.1.1.dist-info}/METADATA +1 -1
- phylogenie-3.1.1.dist-info/RECORD +40 -0
- phylogenie/models.py +0 -17
- phylogenie/utils.py +0 -176
- phylogenie-2.1.30.dist-info/RECORD +0 -39
- {phylogenie-2.1.30.dist-info → phylogenie-3.1.1.dist-info}/LICENSE.txt +0 -0
- {phylogenie-2.1.30.dist-info → phylogenie-3.1.1.dist-info}/WHEEL +0 -0
- {phylogenie-2.1.30.dist-info → phylogenie-3.1.1.dist-info}/entry_points.txt +0 -0
phylogenie/__init__.py
CHANGED
|
@@ -11,7 +11,7 @@ from phylogenie.generators import (
|
|
|
11
11
|
FBDTreeDatasetGenerator,
|
|
12
12
|
TreeDatasetGeneratorConfig,
|
|
13
13
|
)
|
|
14
|
-
from phylogenie.io import
|
|
14
|
+
from phylogenie.io import load_fasta
|
|
15
15
|
from phylogenie.msa import MSA
|
|
16
16
|
from phylogenie.skyline import (
|
|
17
17
|
SkylineMatrix,
|
|
@@ -25,41 +25,47 @@ from phylogenie.skyline import (
|
|
|
25
25
|
skyline_parameter,
|
|
26
26
|
skyline_vector,
|
|
27
27
|
)
|
|
28
|
-
from phylogenie.tree import Tree
|
|
29
28
|
from phylogenie.treesimulator import (
|
|
30
29
|
Birth,
|
|
31
30
|
BirthWithContactTracing,
|
|
32
31
|
Death,
|
|
33
32
|
Event,
|
|
34
33
|
EventType,
|
|
34
|
+
Feature,
|
|
35
35
|
Migration,
|
|
36
36
|
Mutation,
|
|
37
37
|
Sampling,
|
|
38
38
|
SamplingWithContactTracing,
|
|
39
|
+
Tree,
|
|
40
|
+
compute_mean_leaf_pairwise_distance,
|
|
41
|
+
compute_sackin_index,
|
|
42
|
+
dump_newick,
|
|
39
43
|
generate_trees,
|
|
40
44
|
get_BD_events,
|
|
41
45
|
get_BDEI_events,
|
|
42
46
|
get_BDSS_events,
|
|
43
47
|
get_canonical_events,
|
|
44
48
|
get_contact_tracing_events,
|
|
49
|
+
get_distance,
|
|
45
50
|
get_epidemiological_events,
|
|
46
51
|
get_FBD_events,
|
|
47
|
-
simulate_tree,
|
|
48
|
-
)
|
|
49
|
-
from phylogenie.utils import (
|
|
50
|
-
compute_colless_index,
|
|
51
|
-
compute_mean_leaf_pairwise_distance,
|
|
52
|
-
compute_sackin_index,
|
|
53
|
-
get_distance,
|
|
54
52
|
get_mrca,
|
|
53
|
+
get_mutation_id,
|
|
55
54
|
get_node_depth_levels,
|
|
56
55
|
get_node_depths,
|
|
57
56
|
get_node_height_levels,
|
|
58
57
|
get_node_heights,
|
|
59
58
|
get_node_leaf_counts,
|
|
59
|
+
get_node_state,
|
|
60
|
+
load_newick,
|
|
61
|
+
load_nexus,
|
|
62
|
+
set_features,
|
|
63
|
+
simulate_tree,
|
|
60
64
|
)
|
|
61
65
|
|
|
62
66
|
__all__ = [
|
|
67
|
+
"Coloring",
|
|
68
|
+
"draw_tree",
|
|
63
69
|
"AliSimDatasetGenerator",
|
|
64
70
|
"BDEITreeDatasetGenerator",
|
|
65
71
|
"BDSSTreeDatasetGenerator",
|
|
@@ -69,51 +75,52 @@ __all__ = [
|
|
|
69
75
|
"DatasetGeneratorConfig",
|
|
70
76
|
"EpidemiologicalTreeDatasetGenerator",
|
|
71
77
|
"FBDTreeDatasetGenerator",
|
|
78
|
+
"TreeDatasetGeneratorConfig",
|
|
79
|
+
"load_fasta",
|
|
80
|
+
"MSA",
|
|
72
81
|
"SkylineMatrix",
|
|
73
82
|
"SkylineMatrixCoercible",
|
|
74
|
-
"skyline_matrix",
|
|
75
83
|
"SkylineParameter",
|
|
76
84
|
"SkylineParameterLike",
|
|
77
|
-
"skyline_parameter",
|
|
78
85
|
"SkylineVector",
|
|
79
86
|
"SkylineVectorCoercible",
|
|
80
87
|
"SkylineVectorLike",
|
|
88
|
+
"skyline_matrix",
|
|
89
|
+
"skyline_parameter",
|
|
81
90
|
"skyline_vector",
|
|
82
|
-
"Tree",
|
|
83
|
-
"TreeDatasetGeneratorConfig",
|
|
84
91
|
"Birth",
|
|
85
92
|
"BirthWithContactTracing",
|
|
86
93
|
"Death",
|
|
87
94
|
"Event",
|
|
88
95
|
"EventType",
|
|
96
|
+
"Feature",
|
|
89
97
|
"Migration",
|
|
90
98
|
"Mutation",
|
|
91
99
|
"Sampling",
|
|
92
100
|
"SamplingWithContactTracing",
|
|
101
|
+
"Tree",
|
|
102
|
+
"compute_mean_leaf_pairwise_distance",
|
|
103
|
+
"compute_sackin_index",
|
|
104
|
+
"dump_newick",
|
|
105
|
+
"generate_trees",
|
|
93
106
|
"get_BD_events",
|
|
94
107
|
"get_BDEI_events",
|
|
95
108
|
"get_BDSS_events",
|
|
96
109
|
"get_canonical_events",
|
|
97
110
|
"get_contact_tracing_events",
|
|
111
|
+
"get_distance",
|
|
98
112
|
"get_epidemiological_events",
|
|
99
113
|
"get_FBD_events",
|
|
100
|
-
"generate_trees",
|
|
101
|
-
"simulate_tree",
|
|
102
|
-
"dump_newick",
|
|
103
|
-
"load_nexus",
|
|
104
|
-
"load_fasta",
|
|
105
|
-
"load_newick",
|
|
106
|
-
"MSA",
|
|
107
|
-
"Coloring",
|
|
108
|
-
"draw_tree",
|
|
109
|
-
"compute_colless_index",
|
|
110
|
-
"compute_mean_leaf_pairwise_distance",
|
|
111
|
-
"compute_sackin_index",
|
|
112
|
-
"get_distance",
|
|
113
114
|
"get_mrca",
|
|
114
|
-
"
|
|
115
|
+
"get_mutation_id",
|
|
115
116
|
"get_node_depth_levels",
|
|
116
|
-
"
|
|
117
|
+
"get_node_depths",
|
|
117
118
|
"get_node_height_levels",
|
|
119
|
+
"get_node_heights",
|
|
118
120
|
"get_node_leaf_counts",
|
|
121
|
+
"get_node_state",
|
|
122
|
+
"load_newick",
|
|
123
|
+
"load_nexus",
|
|
124
|
+
"set_features",
|
|
125
|
+
"simulate_tree",
|
|
119
126
|
]
|
phylogenie/draw.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
from enum import Enum
|
|
2
|
-
from
|
|
3
|
-
from typing import Any
|
|
2
|
+
from typing import Any, Callable
|
|
4
3
|
|
|
5
4
|
import matplotlib.colors as mcolors
|
|
6
5
|
import matplotlib.patches as mpatches
|
|
@@ -8,8 +7,7 @@ import matplotlib.pyplot as plt
|
|
|
8
7
|
from matplotlib.axes import Axes
|
|
9
8
|
from mpl_toolkits.axes_grid1.inset_locator import inset_axes # pyright: ignore
|
|
10
9
|
|
|
11
|
-
from phylogenie.
|
|
12
|
-
from phylogenie.utils import get_node_depth_levels, get_node_depths
|
|
10
|
+
from phylogenie.treesimulator import Tree, get_node_depth_levels, get_node_depths
|
|
13
11
|
|
|
14
12
|
|
|
15
13
|
class Coloring(str, Enum):
|
|
@@ -26,7 +24,7 @@ def _draw_colored_tree(tree: Tree, ax: Axes, colors: Color | dict[Tree, Color])
|
|
|
26
24
|
|
|
27
25
|
xs = (
|
|
28
26
|
get_node_depth_levels(tree)
|
|
29
|
-
if any(node.branch_length is None for node in
|
|
27
|
+
if any(node.branch_length is None for node in tree.iter_descendants())
|
|
30
28
|
else get_node_depths(tree)
|
|
31
29
|
)
|
|
32
30
|
ys: dict[Tree, float] = {node: i for i, node in enumerate(tree.get_leaves())}
|
|
@@ -50,7 +48,7 @@ def _draw_colored_tree(tree: Tree, ax: Axes, colors: Color | dict[Tree, Color])
|
|
|
50
48
|
def draw_tree(
|
|
51
49
|
tree: Tree,
|
|
52
50
|
ax: Axes | None = None,
|
|
53
|
-
color_by: str | None = None,
|
|
51
|
+
color_by: str | dict[Tree, Any] | None = None,
|
|
54
52
|
coloring: str | Coloring | None = None,
|
|
55
53
|
default_color: Color = "black",
|
|
56
54
|
cmap: str | None = None,
|
|
@@ -69,33 +67,35 @@ def draw_tree(
|
|
|
69
67
|
if color_by is None:
|
|
70
68
|
return _draw_colored_tree(tree, ax, colors=default_color)
|
|
71
69
|
|
|
72
|
-
|
|
70
|
+
if isinstance(color_by, dict):
|
|
71
|
+
features = {node: color_by[node] for node in tree if node in color_by}
|
|
72
|
+
else:
|
|
73
|
+
features = {node: node[color_by] for node in tree if color_by in node.metadata}
|
|
73
74
|
|
|
74
75
|
if coloring is None:
|
|
75
76
|
coloring = (
|
|
76
77
|
Coloring.CONTINUOUS
|
|
77
|
-
if any(isinstance(f, float) for f in features)
|
|
78
|
+
if any(isinstance(f, float) for f in features.values())
|
|
78
79
|
else Coloring.DISCRETE
|
|
79
80
|
)
|
|
80
81
|
|
|
82
|
+
def _get_colors(feature_map: Callable[[Any], Color]) -> dict[Tree, Color]:
|
|
83
|
+
return {
|
|
84
|
+
node: feature_map(features[node]) if node in features else default_color
|
|
85
|
+
for node in tree
|
|
86
|
+
}
|
|
87
|
+
|
|
81
88
|
if coloring == Coloring.DISCRETE:
|
|
82
|
-
if any(isinstance(f, float) for f in features):
|
|
89
|
+
if any(isinstance(f, float) for f in features.values()):
|
|
83
90
|
raise ValueError(
|
|
84
91
|
"Discrete coloring selected but feature values are not all categorical."
|
|
85
92
|
)
|
|
86
93
|
|
|
87
94
|
colormap = plt.get_cmap("tab20" if cmap is None else cmap)
|
|
88
95
|
feature_colors = {
|
|
89
|
-
f: mcolors.to_hex(colormap(i)) for i, f in enumerate(set(features))
|
|
90
|
-
}
|
|
91
|
-
colors = {
|
|
92
|
-
node: (
|
|
93
|
-
feature_colors[node.get(color_by)]
|
|
94
|
-
if color_by in node.features
|
|
95
|
-
else default_color
|
|
96
|
-
)
|
|
97
|
-
for node in tree
|
|
96
|
+
f: mcolors.to_hex(colormap(i)) for i, f in enumerate(set(features.values()))
|
|
98
97
|
}
|
|
98
|
+
colors = _get_colors(lambda f: feature_colors[f])
|
|
99
99
|
|
|
100
100
|
if show_legend:
|
|
101
101
|
legend_handles = [
|
|
@@ -105,7 +105,7 @@ def draw_tree(
|
|
|
105
105
|
)
|
|
106
106
|
for f in feature_colors
|
|
107
107
|
]
|
|
108
|
-
if any(color_by not in node.
|
|
108
|
+
if any(color_by not in node.metadata for node in tree):
|
|
109
109
|
legend_handles.append(mpatches.Patch(color=default_color, label="NA"))
|
|
110
110
|
if legend_kwargs is None:
|
|
111
111
|
legend_kwargs = {}
|
|
@@ -114,18 +114,11 @@ def draw_tree(
|
|
|
114
114
|
return _draw_colored_tree(tree, ax, colors)
|
|
115
115
|
|
|
116
116
|
if coloring == Coloring.CONTINUOUS:
|
|
117
|
-
vmin = min(features) if vmin is None else vmin
|
|
118
|
-
vmax = max(features) if vmax is None else vmax
|
|
117
|
+
vmin = min(features.values()) if vmin is None else vmin
|
|
118
|
+
vmax = max(features.values()) if vmax is None else vmax
|
|
119
119
|
norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
|
|
120
120
|
colormap = plt.get_cmap("viridis" if cmap is None else cmap)
|
|
121
|
-
colors =
|
|
122
|
-
node: (
|
|
123
|
-
colormap(norm(float(node.get(color_by))))
|
|
124
|
-
if color_by in node.features
|
|
125
|
-
else default_color
|
|
126
|
-
)
|
|
127
|
-
for node in tree
|
|
128
|
-
}
|
|
121
|
+
colors = _get_colors(lambda f: colormap(norm(float(f))))
|
|
129
122
|
|
|
130
123
|
if show_hist:
|
|
131
124
|
default_hist_axes_kwargs = {"width": "25%", "height": "25%"}
|
phylogenie/generators/alisim.py
CHANGED
|
@@ -8,8 +8,7 @@ from numpy.random import Generator, default_rng
|
|
|
8
8
|
from phylogenie.generators.dataset import DatasetGenerator, DataType
|
|
9
9
|
from phylogenie.generators.factories import data, string
|
|
10
10
|
from phylogenie.generators.trees import TreeDatasetGeneratorConfig
|
|
11
|
-
from phylogenie.
|
|
12
|
-
from phylogenie.utils import get_node_depths
|
|
11
|
+
from phylogenie.treesimulator import dump_newick, get_node_depths
|
|
13
12
|
|
|
14
13
|
MSAS_DIRNAME = "MSAs"
|
|
15
14
|
TREES_DIRNAME = "trees"
|
|
@@ -60,26 +59,26 @@ class AliSimDatasetGenerator(DatasetGenerator):
|
|
|
60
59
|
tree_filename = f"{filename}.temp-tree"
|
|
61
60
|
msa_filename = filename
|
|
62
61
|
|
|
63
|
-
|
|
62
|
+
md: dict[str, Any] = {"file_id": Path(msa_filename).stem}
|
|
64
63
|
rng = default_rng(seed)
|
|
65
64
|
while True:
|
|
66
|
-
|
|
65
|
+
md.update(data(context, rng))
|
|
67
66
|
try:
|
|
68
|
-
tree, metadata = self.trees.simulate_one(
|
|
67
|
+
tree, metadata = self.trees.simulate_one(md, seed)
|
|
69
68
|
break
|
|
70
69
|
except TimeoutError:
|
|
71
70
|
print(
|
|
72
71
|
"Tree simulation timed out, retrying with different parameters..."
|
|
73
72
|
)
|
|
74
|
-
|
|
73
|
+
md.update(metadata)
|
|
75
74
|
|
|
76
75
|
times = get_node_depths(tree)
|
|
77
76
|
for leaf in tree.get_leaves():
|
|
78
77
|
leaf.name += f"|{times[leaf]}"
|
|
79
78
|
dump_newick(tree, f"{tree_filename}.nwk")
|
|
80
79
|
|
|
81
|
-
self._generate_one_from_tree(msa_filename, f"{tree_filename}.nwk", rng,
|
|
80
|
+
self._generate_one_from_tree(msa_filename, f"{tree_filename}.nwk", rng, md)
|
|
82
81
|
if not self.keep_trees:
|
|
83
82
|
os.remove(f"{tree_filename}.nwk")
|
|
84
83
|
|
|
85
|
-
return
|
|
84
|
+
return md
|
phylogenie/generators/configs.py
CHANGED
|
@@ -1,7 +1,29 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from numpy.random import Generator
|
|
4
|
+
from pydantic import BaseModel, ConfigDict
|
|
5
|
+
|
|
1
6
|
import phylogenie.typings as pgt
|
|
2
|
-
from phylogenie.models import Distribution, StrictBaseModel
|
|
3
7
|
from phylogenie.treesimulator import EventType
|
|
4
8
|
|
|
9
|
+
|
|
10
|
+
class StrictBaseModel(BaseModel):
|
|
11
|
+
model_config = ConfigDict(extra="forbid")
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Distribution(BaseModel):
|
|
15
|
+
type: str
|
|
16
|
+
model_config = ConfigDict(extra="allow")
|
|
17
|
+
|
|
18
|
+
@property
|
|
19
|
+
def args(self) -> dict[str, Any]:
|
|
20
|
+
assert self.model_extra is not None
|
|
21
|
+
return self.model_extra
|
|
22
|
+
|
|
23
|
+
def __call__(self, rng: Generator) -> Any:
|
|
24
|
+
return getattr(rng, self.type)(**self.args)
|
|
25
|
+
|
|
26
|
+
|
|
5
27
|
Integer = str | int
|
|
6
28
|
Scalar = str | pgt.Scalar
|
|
7
29
|
ManyScalars = str | pgt.Many[Scalar]
|
phylogenie/generators/dataset.py
CHANGED
|
@@ -8,7 +8,7 @@ import pandas as pd
|
|
|
8
8
|
from numpy.random import Generator, default_rng
|
|
9
9
|
from tqdm import tqdm
|
|
10
10
|
|
|
11
|
-
from phylogenie.
|
|
11
|
+
from phylogenie.generators.configs import Distribution, StrictBaseModel
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
class DataType(str, Enum):
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import re
|
|
2
|
-
from typing import Any
|
|
2
|
+
from typing import Any, Callable
|
|
3
3
|
|
|
4
4
|
import numpy as np
|
|
5
5
|
from numpy.random import Generator
|
|
@@ -8,7 +8,6 @@ import phylogenie.generators.configs as cfg
|
|
|
8
8
|
import phylogenie.generators.typeguards as ctg
|
|
9
9
|
import phylogenie.typeguards as tg
|
|
10
10
|
import phylogenie.typings as pgt
|
|
11
|
-
from phylogenie.models import Distribution
|
|
12
11
|
from phylogenie.skyline import (
|
|
13
12
|
SkylineMatrix,
|
|
14
13
|
SkylineMatrixCoercible,
|
|
@@ -24,11 +23,7 @@ def _eval_expression(expression: str, data: dict[str, Any]) -> Any:
|
|
|
24
23
|
return np.array(
|
|
25
24
|
eval(
|
|
26
25
|
expression,
|
|
27
|
-
{
|
|
28
|
-
"__builtins__": __builtins__,
|
|
29
|
-
"np": np,
|
|
30
|
-
**{k: np.array(v) for k, v in data.items()},
|
|
31
|
-
},
|
|
26
|
+
{"np": np, **{k: np.array(v) for k, v in data.items()}},
|
|
32
27
|
)
|
|
33
28
|
).tolist()
|
|
34
29
|
|
|
@@ -214,12 +209,12 @@ def skyline_matrix(
|
|
|
214
209
|
return SkylineMatrix(value=value, change_times=change_times)
|
|
215
210
|
|
|
216
211
|
|
|
217
|
-
def distribution(x: Distribution, data: dict[str, Any]) -> Distribution:
|
|
212
|
+
def distribution(x: cfg.Distribution, data: dict[str, Any]) -> cfg.Distribution:
|
|
218
213
|
args = x.args
|
|
219
214
|
for arg_name, arg_value in args.items():
|
|
220
215
|
if isinstance(arg_value, str):
|
|
221
216
|
args[arg_name] = _eval_expression(arg_value, data)
|
|
222
|
-
return Distribution(type=x.type, **args)
|
|
217
|
+
return cfg.Distribution(type=x.type, **args)
|
|
223
218
|
|
|
224
219
|
|
|
225
220
|
def mutations(
|
|
@@ -227,11 +222,14 @@ def mutations(
|
|
|
227
222
|
data: dict[str, Any],
|
|
228
223
|
states: set[str],
|
|
229
224
|
rates_to_log: list[EventType] | None,
|
|
225
|
+
rng: Generator,
|
|
230
226
|
) -> list[Mutation]:
|
|
231
227
|
mutations: list[Mutation] = []
|
|
232
228
|
for m in x:
|
|
233
229
|
rate = skyline_parameter(m.rate, data)
|
|
234
|
-
rate_scalers
|
|
230
|
+
rate_scalers: dict[EventType, Callable[[], float]] = {
|
|
231
|
+
k: lambda: distribution(v, data)(rng) for k, v in m.rate_scalers.items()
|
|
232
|
+
}
|
|
235
233
|
if m.state is None:
|
|
236
234
|
mutations.extend(
|
|
237
235
|
Mutation(s, rate, rate_scalers, rates_to_log) for s in states
|
|
@@ -241,11 +239,10 @@ def mutations(
|
|
|
241
239
|
return mutations
|
|
242
240
|
|
|
243
241
|
|
|
244
|
-
def data(context: dict[str, Distribution] | None, rng: Generator) -> dict[str, Any]:
|
|
242
|
+
def data(context: dict[str, cfg.Distribution] | None, rng: Generator) -> dict[str, Any]:
|
|
245
243
|
if context is None:
|
|
246
244
|
return {}
|
|
247
245
|
data: dict[str, Any] = {}
|
|
248
246
|
for k, v in context.items():
|
|
249
|
-
|
|
250
|
-
data[k] = np.array(getattr(rng, dist.type)(**dist.args)).tolist()
|
|
247
|
+
data[k] = np.array(distribution(v, data)(rng)).tolist()
|
|
251
248
|
return data
|
phylogenie/generators/trees.py
CHANGED
|
@@ -1,13 +1,14 @@
|
|
|
1
1
|
from abc import abstractmethod
|
|
2
2
|
from enum import Enum
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import Annotated, Any, Literal
|
|
4
|
+
from typing import Annotated, Any, Callable, Literal
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
from numpy.random import default_rng
|
|
8
8
|
from pydantic import Field
|
|
9
9
|
|
|
10
10
|
import phylogenie.generators.configs as cfg
|
|
11
|
+
from phylogenie.generators.configs import Distribution
|
|
11
12
|
from phylogenie.generators.dataset import DatasetGenerator, DataType
|
|
12
13
|
from phylogenie.generators.factories import (
|
|
13
14
|
data,
|
|
@@ -18,13 +19,12 @@ from phylogenie.generators.factories import (
|
|
|
18
19
|
skyline_parameter,
|
|
19
20
|
skyline_vector,
|
|
20
21
|
)
|
|
21
|
-
from phylogenie.io import dump_newick
|
|
22
|
-
from phylogenie.models import Distribution
|
|
23
|
-
from phylogenie.tree import Tree
|
|
24
22
|
from phylogenie.treesimulator import (
|
|
25
23
|
Event,
|
|
26
24
|
EventType,
|
|
27
25
|
Feature,
|
|
26
|
+
Tree,
|
|
27
|
+
dump_newick,
|
|
28
28
|
get_BD_events,
|
|
29
29
|
get_BDEI_events,
|
|
30
30
|
get_BDSS_events,
|
|
@@ -50,13 +50,13 @@ class TreeDatasetGenerator(DatasetGenerator):
|
|
|
50
50
|
data_type: Literal[DataType.TREES] = DataType.TREES
|
|
51
51
|
mutations: list[cfg.Mutation] = Field(default_factory=lambda: [])
|
|
52
52
|
rates_to_log: list[EventType] | None = None
|
|
53
|
-
|
|
54
|
-
max_tips: cfg.Integer | None = None
|
|
53
|
+
n_tips: cfg.Integer | None = None
|
|
55
54
|
max_time: cfg.Scalar = np.inf
|
|
56
55
|
init_state: str | None = None
|
|
57
56
|
sampling_probability_at_present: cfg.Scalar = 0.0
|
|
58
57
|
timeout: float = np.inf
|
|
59
58
|
node_features: list[Feature] | None = None
|
|
59
|
+
acceptance_criterion: str | None = None
|
|
60
60
|
|
|
61
61
|
@abstractmethod
|
|
62
62
|
def _get_events(self, data: dict[str, Any]) -> list[Event]: ...
|
|
@@ -71,11 +71,19 @@ class TreeDatasetGenerator(DatasetGenerator):
|
|
|
71
71
|
)
|
|
72
72
|
events = self._get_events(data)
|
|
73
73
|
states = {e.state for e in events}
|
|
74
|
-
events += mutations(
|
|
74
|
+
events += mutations(
|
|
75
|
+
self.mutations, data, states, self.rates_to_log, default_rng(seed)
|
|
76
|
+
)
|
|
77
|
+
acceptance_criterion: None | Callable[[Tree], bool] = (
|
|
78
|
+
None
|
|
79
|
+
if self.acceptance_criterion is None
|
|
80
|
+
else lambda tree: eval(
|
|
81
|
+
self.acceptance_criterion, {}, {"tree": tree} # pyright: ignore
|
|
82
|
+
)
|
|
83
|
+
)
|
|
75
84
|
return simulate_tree(
|
|
76
85
|
events=events,
|
|
77
|
-
|
|
78
|
-
max_tips=None if self.max_tips is None else integer(self.max_tips, data),
|
|
86
|
+
n_tips=None if self.n_tips is None else integer(self.n_tips, data),
|
|
79
87
|
max_time=scalar(self.max_time, data),
|
|
80
88
|
init_state=init_state,
|
|
81
89
|
sampling_probability_at_present=scalar(
|
|
@@ -83,6 +91,7 @@ class TreeDatasetGenerator(DatasetGenerator):
|
|
|
83
91
|
),
|
|
84
92
|
seed=seed,
|
|
85
93
|
timeout=self.timeout,
|
|
94
|
+
acceptance_criterion=acceptance_criterion,
|
|
86
95
|
)
|
|
87
96
|
|
|
88
97
|
def generate_one(
|
|
@@ -157,7 +166,7 @@ class FBDTreeDatasetGenerator(TreeDatasetGenerator):
|
|
|
157
166
|
class ContactTracingTreeDatasetGenerator(TreeDatasetGenerator):
|
|
158
167
|
max_notified_contacts: cfg.Integer = 1
|
|
159
168
|
notification_probability: cfg.SkylineParameter = 0.0
|
|
160
|
-
sampling_rate_after_notification: cfg.SkylineParameter =
|
|
169
|
+
sampling_rate_after_notification: cfg.SkylineParameter = 2**32
|
|
161
170
|
samplable_states_after_notification: list[str] | None = None
|
|
162
171
|
|
|
163
172
|
@abstractmethod
|
phylogenie/io/__init__.py
CHANGED
phylogenie/mixins.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
from types import MappingProxyType
|
|
2
|
+
from typing import Any, Mapping, Optional
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class MetadataMixin:
|
|
6
|
+
"""A mixin that provides metadata management with dictionary-like access."""
|
|
7
|
+
|
|
8
|
+
def __init__(self) -> None:
|
|
9
|
+
self._metadata: dict[str, Any] = {}
|
|
10
|
+
|
|
11
|
+
@property
|
|
12
|
+
def metadata(self) -> Mapping[str, Any]:
|
|
13
|
+
"""Return a read-only view of all metadata."""
|
|
14
|
+
return MappingProxyType(self._metadata)
|
|
15
|
+
|
|
16
|
+
def set(self, key: str, value: Any) -> None:
|
|
17
|
+
"""Set or update a metadata value."""
|
|
18
|
+
self._metadata[key] = value
|
|
19
|
+
|
|
20
|
+
def update(self, metadata: Mapping[str, Any]) -> None:
|
|
21
|
+
"""Bulk update metadata values."""
|
|
22
|
+
self._metadata.update(metadata)
|
|
23
|
+
|
|
24
|
+
def get(self, key: str, default: Optional[Any] = None) -> Any:
|
|
25
|
+
"""Get a metadata value, returning `default` if not found."""
|
|
26
|
+
return self._metadata.get(key, default)
|
|
27
|
+
|
|
28
|
+
def delete(self, key: str) -> None:
|
|
29
|
+
"""Delete a metadata if it exists, else do nothing."""
|
|
30
|
+
self._metadata.pop(key, None)
|
|
31
|
+
|
|
32
|
+
def clear(self) -> None:
|
|
33
|
+
"""Remove all metadata."""
|
|
34
|
+
self._metadata.clear()
|
|
35
|
+
|
|
36
|
+
# Dict-like behavior
|
|
37
|
+
def __getitem__(self, key: str) -> Any:
|
|
38
|
+
return self._metadata[key]
|
|
39
|
+
|
|
40
|
+
def __setitem__(self, key: str, value: Any) -> None:
|
|
41
|
+
self._metadata[key] = value
|
|
@@ -19,7 +19,20 @@ from phylogenie.treesimulator.events import (
|
|
|
19
19
|
)
|
|
20
20
|
from phylogenie.treesimulator.features import Feature, set_features
|
|
21
21
|
from phylogenie.treesimulator.gillespie import generate_trees, simulate_tree
|
|
22
|
+
from phylogenie.treesimulator.io import dump_newick, load_newick, load_nexus
|
|
22
23
|
from phylogenie.treesimulator.model import get_node_state
|
|
24
|
+
from phylogenie.treesimulator.tree import Tree
|
|
25
|
+
from phylogenie.treesimulator.utils import (
|
|
26
|
+
compute_mean_leaf_pairwise_distance,
|
|
27
|
+
compute_sackin_index,
|
|
28
|
+
get_distance,
|
|
29
|
+
get_mrca,
|
|
30
|
+
get_node_depth_levels,
|
|
31
|
+
get_node_depths,
|
|
32
|
+
get_node_height_levels,
|
|
33
|
+
get_node_heights,
|
|
34
|
+
get_node_leaf_counts,
|
|
35
|
+
)
|
|
23
36
|
|
|
24
37
|
__all__ = [
|
|
25
38
|
"Birth",
|
|
@@ -38,10 +51,23 @@ __all__ = [
|
|
|
38
51
|
"get_contact_tracing_events",
|
|
39
52
|
"get_epidemiological_events",
|
|
40
53
|
"get_FBD_events",
|
|
41
|
-
"generate_trees",
|
|
42
|
-
"simulate_tree",
|
|
43
54
|
"get_mutation_id",
|
|
44
|
-
"get_node_state",
|
|
45
55
|
"Feature",
|
|
46
56
|
"set_features",
|
|
57
|
+
"simulate_tree",
|
|
58
|
+
"dump_newick",
|
|
59
|
+
"load_newick",
|
|
60
|
+
"load_nexus",
|
|
61
|
+
"generate_trees",
|
|
62
|
+
"get_node_state",
|
|
63
|
+
"Tree",
|
|
64
|
+
"compute_mean_leaf_pairwise_distance",
|
|
65
|
+
"compute_sackin_index",
|
|
66
|
+
"get_distance",
|
|
67
|
+
"get_mrca",
|
|
68
|
+
"get_node_depth_levels",
|
|
69
|
+
"get_node_depths",
|
|
70
|
+
"get_node_height_levels",
|
|
71
|
+
"get_node_heights",
|
|
72
|
+
"get_node_leaf_counts",
|
|
47
73
|
]
|
|
@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
|
|
|
2
2
|
from enum import Enum
|
|
3
3
|
from typing import Any
|
|
4
4
|
|
|
5
|
-
import numpy as np
|
|
6
5
|
from numpy.random import Generator
|
|
7
6
|
|
|
8
7
|
from phylogenie.skyline import SkylineParameterLike, skyline_parameter
|
|
@@ -32,8 +31,6 @@ class Event(ABC):
|
|
|
32
31
|
def get_propensity(self, model: Model, time: float) -> float:
|
|
33
32
|
n_individuals = model.count_individuals(self.state)
|
|
34
33
|
rate = self.rate.get_value_at_time(time)
|
|
35
|
-
if rate == np.inf and not n_individuals:
|
|
36
|
-
return 0
|
|
37
34
|
return rate * n_individuals
|
|
38
35
|
|
|
39
36
|
@abstractmethod
|
|
@@ -2,7 +2,6 @@ from collections import defaultdict
|
|
|
2
2
|
from collections.abc import Sequence
|
|
3
3
|
from copy import deepcopy
|
|
4
4
|
|
|
5
|
-
import numpy as np
|
|
6
5
|
from numpy.random import Generator
|
|
7
6
|
|
|
8
7
|
from phylogenie.skyline import SkylineParameterLike, skyline_parameter
|
|
@@ -32,10 +31,10 @@ class BirthWithContactTracing(Event):
|
|
|
32
31
|
def apply(self, model: Model, events: list[Event], time: float, rng: Generator):
|
|
33
32
|
individual = self.draw_individual(model, rng)
|
|
34
33
|
new_individual = model.birth_from(individual, self.child_state, time)
|
|
35
|
-
if CONTACTS_KEY not in model.
|
|
36
|
-
model
|
|
37
|
-
model
|
|
38
|
-
model
|
|
34
|
+
if CONTACTS_KEY not in model.metadata:
|
|
35
|
+
model[CONTACTS_KEY] = defaultdict(list)
|
|
36
|
+
model[CONTACTS_KEY][individual].append(new_individual)
|
|
37
|
+
model[CONTACTS_KEY][new_individual].append(individual)
|
|
39
38
|
|
|
40
39
|
def __repr__(self) -> str:
|
|
41
40
|
return f"BirthWithContactTracing(state={self.state}, rate={self.rate}, child_state={self.child_state})"
|
|
@@ -59,9 +58,9 @@ class SamplingWithContactTracing(Event):
|
|
|
59
58
|
individual = self.draw_individual(model, rng)
|
|
60
59
|
model.sample(individual, time, True)
|
|
61
60
|
population = model.get_population()
|
|
62
|
-
if CONTACTS_KEY not in model.
|
|
61
|
+
if CONTACTS_KEY not in model.metadata:
|
|
63
62
|
return
|
|
64
|
-
contacts = model
|
|
63
|
+
contacts = model[CONTACTS_KEY][individual]
|
|
65
64
|
for contact in contacts[-self.max_notified_contacts :]:
|
|
66
65
|
if contact in population:
|
|
67
66
|
state = model.get_state(contact)
|
|
@@ -76,8 +75,8 @@ class SamplingWithContactTracing(Event):
|
|
|
76
75
|
def get_contact_tracing_events(
|
|
77
76
|
events: Sequence[Event],
|
|
78
77
|
max_notified_contacts: int = 1,
|
|
79
|
-
notification_probability: SkylineParameterLike =
|
|
80
|
-
sampling_rate_after_notification: SkylineParameterLike =
|
|
78
|
+
notification_probability: SkylineParameterLike = 0.0,
|
|
79
|
+
sampling_rate_after_notification: SkylineParameterLike = 2**32,
|
|
81
80
|
samplable_states_after_notification: list[str] | None = None,
|
|
82
81
|
) -> list[Event]:
|
|
83
82
|
ct_events: list[Event] = []
|