phylogenie 2.1.0__py3-none-any.whl → 2.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- phylogenie/__init__.py +20 -1
- phylogenie/generators/configs.py +11 -12
- phylogenie/generators/dataset.py +6 -9
- phylogenie/generators/factories.py +19 -10
- phylogenie/generators/trees.py +28 -7
- phylogenie/io.py +22 -10
- phylogenie/tree.py +42 -8
- phylogenie/treesimulator/__init__.py +16 -0
- phylogenie/treesimulator/events/__init__.py +39 -0
- phylogenie/treesimulator/events/contact_tracing.py +125 -0
- phylogenie/treesimulator/{events.py → events/core.py} +29 -119
- phylogenie/treesimulator/events/mutations.py +105 -0
- phylogenie/treesimulator/gillespie.py +10 -15
- phylogenie/treesimulator/model.py +48 -20
- phylogenie/typings.py +0 -1
- phylogenie/utils.py +17 -0
- {phylogenie-2.1.0.dist-info → phylogenie-2.1.2.dist-info}/METADATA +1 -2
- phylogenie-2.1.2.dist-info/RECORD +32 -0
- phylogenie-2.1.0.dist-info/RECORD +0 -28
- {phylogenie-2.1.0.dist-info → phylogenie-2.1.2.dist-info}/LICENSE.txt +0 -0
- {phylogenie-2.1.0.dist-info → phylogenie-2.1.2.dist-info}/WHEEL +0 -0
- {phylogenie-2.1.0.dist-info → phylogenie-2.1.2.dist-info}/entry_points.txt +0 -0
phylogenie/__init__.py
CHANGED
|
@@ -10,7 +10,7 @@ from phylogenie.generators import (
|
|
|
10
10
|
FBDTreeDatasetGenerator,
|
|
11
11
|
TreeDatasetGeneratorConfig,
|
|
12
12
|
)
|
|
13
|
-
from phylogenie.io import load_fasta, load_newick
|
|
13
|
+
from phylogenie.io import dump_newick, load_fasta, load_newick
|
|
14
14
|
from phylogenie.msa import MSA
|
|
15
15
|
from phylogenie.skyline import (
|
|
16
16
|
SkylineMatrix,
|
|
@@ -26,12 +26,21 @@ from phylogenie.skyline import (
|
|
|
26
26
|
)
|
|
27
27
|
from phylogenie.tree import Tree
|
|
28
28
|
from phylogenie.treesimulator import (
|
|
29
|
+
Birth,
|
|
30
|
+
BirthWithContactTracing,
|
|
31
|
+
Death,
|
|
29
32
|
Event,
|
|
33
|
+
Migration,
|
|
34
|
+
Mutation,
|
|
35
|
+
MutationTargetType,
|
|
36
|
+
Sampling,
|
|
37
|
+
SamplingWithContactTracing,
|
|
30
38
|
generate_trees,
|
|
31
39
|
get_BD_events,
|
|
32
40
|
get_BDEI_events,
|
|
33
41
|
get_BDSS_events,
|
|
34
42
|
get_canonical_events,
|
|
43
|
+
get_contact_tracing_events,
|
|
35
44
|
get_epidemiological_events,
|
|
36
45
|
get_FBD_events,
|
|
37
46
|
simulate_tree,
|
|
@@ -59,15 +68,25 @@ __all__ = [
|
|
|
59
68
|
"skyline_vector",
|
|
60
69
|
"Tree",
|
|
61
70
|
"TreeDatasetGeneratorConfig",
|
|
71
|
+
"Birth",
|
|
72
|
+
"BirthWithContactTracing",
|
|
73
|
+
"Death",
|
|
62
74
|
"Event",
|
|
75
|
+
"Migration",
|
|
76
|
+
"Mutation",
|
|
77
|
+
"MutationTargetType",
|
|
78
|
+
"Sampling",
|
|
79
|
+
"SamplingWithContactTracing",
|
|
63
80
|
"get_BD_events",
|
|
64
81
|
"get_BDEI_events",
|
|
65
82
|
"get_BDSS_events",
|
|
66
83
|
"get_canonical_events",
|
|
84
|
+
"get_contact_tracing_events",
|
|
67
85
|
"get_epidemiological_events",
|
|
68
86
|
"get_FBD_events",
|
|
69
87
|
"generate_trees",
|
|
70
88
|
"simulate_tree",
|
|
89
|
+
"dump_newick",
|
|
71
90
|
"load_fasta",
|
|
72
91
|
"load_newick",
|
|
73
92
|
"MSA",
|
phylogenie/generators/configs.py
CHANGED
|
@@ -1,12 +1,6 @@
|
|
|
1
|
-
from pydantic import BaseModel, ConfigDict
|
|
2
|
-
|
|
3
1
|
import phylogenie.typings as pgt
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
class Distribution(BaseModel):
|
|
7
|
-
type: str
|
|
8
|
-
model_config = ConfigDict(extra="allow")
|
|
9
|
-
|
|
2
|
+
from phylogenie.treesimulator import MutationTargetType
|
|
3
|
+
from phylogenie.utils import Distribution, StrictBaseModel
|
|
10
4
|
|
|
11
5
|
Integer = str | int
|
|
12
6
|
Scalar = str | pgt.Scalar
|
|
@@ -15,10 +9,6 @@ OneOrManyScalars = Scalar | pgt.Many[Scalar]
|
|
|
15
9
|
OneOrMany2DScalars = Scalar | pgt.Many2D[Scalar]
|
|
16
10
|
|
|
17
11
|
|
|
18
|
-
class StrictBaseModel(BaseModel):
|
|
19
|
-
model_config = ConfigDict(extra="forbid")
|
|
20
|
-
|
|
21
|
-
|
|
22
12
|
class SkylineParameterModel(StrictBaseModel):
|
|
23
13
|
value: ManyScalars
|
|
24
14
|
change_times: ManyScalars
|
|
@@ -37,3 +27,12 @@ class SkylineMatrixModel(StrictBaseModel):
|
|
|
37
27
|
SkylineParameter = Scalar | SkylineParameterModel
|
|
38
28
|
SkylineVector = str | pgt.Scalar | pgt.Many[SkylineParameter] | SkylineVectorModel
|
|
39
29
|
SkylineMatrix = str | pgt.Scalar | pgt.Many[SkylineVector] | SkylineMatrixModel | None
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class Event(StrictBaseModel):
|
|
33
|
+
states: str | list[str] | None = None
|
|
34
|
+
rate: SkylineParameter
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class Mutation(Event):
|
|
38
|
+
rate_scalers: dict[MutationTargetType, Distribution]
|
phylogenie/generators/dataset.py
CHANGED
|
@@ -10,8 +10,8 @@ import pandas as pd
|
|
|
10
10
|
from numpy.random import Generator, default_rng
|
|
11
11
|
from tqdm import tqdm
|
|
12
12
|
|
|
13
|
-
|
|
14
|
-
from phylogenie.
|
|
13
|
+
from phylogenie.generators.factories import distribution
|
|
14
|
+
from phylogenie.utils import Distribution, StrictBaseModel
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
class DataType(str, Enum):
|
|
@@ -23,12 +23,12 @@ DATA_DIRNAME = "data"
|
|
|
23
23
|
METADATA_FILENAME = "metadata.csv"
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
class DatasetGenerator(ABC,
|
|
26
|
+
class DatasetGenerator(ABC, StrictBaseModel):
|
|
27
27
|
output_dir: str = "phylogenie-outputs"
|
|
28
28
|
n_samples: int | dict[str, int] = 1
|
|
29
29
|
n_jobs: int = -1
|
|
30
30
|
seed: int | None = None
|
|
31
|
-
context: dict[str,
|
|
31
|
+
context: dict[str, Distribution] | None = None
|
|
32
32
|
|
|
33
33
|
@abstractmethod
|
|
34
34
|
def _generate_one(
|
|
@@ -56,11 +56,8 @@ class DatasetGenerator(ABC, cfg.StrictBaseModel):
|
|
|
56
56
|
data: list[dict[str, Any]] = [{} for _ in range(n_samples)]
|
|
57
57
|
if self.context is not None:
|
|
58
58
|
for d, (k, v) in product(data, self.context.items()):
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
if isinstance(arg_value, str):
|
|
62
|
-
args[arg_name] = eval_expression(arg_value, d)
|
|
63
|
-
d[k] = np.array(getattr(rng, v.type)(**args)).tolist()
|
|
59
|
+
dist = distribution(v, d)
|
|
60
|
+
d[k] = np.array(getattr(rng, dist.type)(**dist.args)).tolist()
|
|
64
61
|
df = pd.DataFrame([{"file_id": str(i), **d} for i, d in enumerate(data)])
|
|
65
62
|
df.to_csv(os.path.join(output_dir, METADATA_FILENAME), index=False)
|
|
66
63
|
|
|
@@ -14,9 +14,10 @@ from phylogenie.skyline import (
|
|
|
14
14
|
SkylineVector,
|
|
15
15
|
SkylineVectorCoercible,
|
|
16
16
|
)
|
|
17
|
+
from phylogenie.utils import Distribution
|
|
17
18
|
|
|
18
19
|
|
|
19
|
-
def
|
|
20
|
+
def _eval_expression(expression: str, data: dict[str, Any]) -> Any:
|
|
20
21
|
return np.array(
|
|
21
22
|
eval(
|
|
22
23
|
expression,
|
|
@@ -29,9 +30,17 @@ def eval_expression(expression: str, data: dict[str, Any]) -> Any:
|
|
|
29
30
|
).tolist()
|
|
30
31
|
|
|
31
32
|
|
|
33
|
+
def distribution(x: Distribution, data: dict[str, Any]) -> Distribution:
|
|
34
|
+
args = x.args
|
|
35
|
+
for arg_name, arg_value in args.items():
|
|
36
|
+
if isinstance(arg_value, str):
|
|
37
|
+
args[arg_name] = _eval_expression(arg_value, data)
|
|
38
|
+
return Distribution(type=x.type, **args)
|
|
39
|
+
|
|
40
|
+
|
|
32
41
|
def integer(x: cfg.Integer, data: dict[str, Any]) -> int:
|
|
33
42
|
if isinstance(x, str):
|
|
34
|
-
e =
|
|
43
|
+
e = _eval_expression(x, data)
|
|
35
44
|
if isinstance(e, int):
|
|
36
45
|
return e
|
|
37
46
|
raise ValueError(
|
|
@@ -42,7 +51,7 @@ def integer(x: cfg.Integer, data: dict[str, Any]) -> int:
|
|
|
42
51
|
|
|
43
52
|
def scalar(x: cfg.Scalar, data: dict[str, Any]) -> pgt.Scalar:
|
|
44
53
|
if isinstance(x, str):
|
|
45
|
-
e =
|
|
54
|
+
e = _eval_expression(x, data)
|
|
46
55
|
if isinstance(e, pgt.Scalar):
|
|
47
56
|
return e
|
|
48
57
|
raise ValueError(
|
|
@@ -53,7 +62,7 @@ def scalar(x: cfg.Scalar, data: dict[str, Any]) -> pgt.Scalar:
|
|
|
53
62
|
|
|
54
63
|
def many_scalars(x: cfg.ManyScalars, data: dict[str, Any]) -> pgt.ManyScalars:
|
|
55
64
|
if isinstance(x, str):
|
|
56
|
-
e =
|
|
65
|
+
e = _eval_expression(x, data)
|
|
57
66
|
if tg.is_many_scalars(e):
|
|
58
67
|
return e
|
|
59
68
|
raise ValueError(
|
|
@@ -66,7 +75,7 @@ def one_or_many_scalars(
|
|
|
66
75
|
x: cfg.OneOrManyScalars, data: dict[str, Any]
|
|
67
76
|
) -> pgt.OneOrManyScalars:
|
|
68
77
|
if isinstance(x, str):
|
|
69
|
-
e =
|
|
78
|
+
e = _eval_expression(x, data)
|
|
70
79
|
if tg.is_one_or_many_scalars(e):
|
|
71
80
|
return e
|
|
72
81
|
raise ValueError(
|
|
@@ -92,7 +101,7 @@ def skyline_vector(
|
|
|
92
101
|
x: cfg.SkylineVector, data: dict[str, Any]
|
|
93
102
|
) -> SkylineVectorCoercible:
|
|
94
103
|
if isinstance(x, str):
|
|
95
|
-
e =
|
|
104
|
+
e = _eval_expression(x, data)
|
|
96
105
|
if tg.is_one_or_many_scalars(e):
|
|
97
106
|
return e
|
|
98
107
|
raise ValueError(
|
|
@@ -107,7 +116,7 @@ def skyline_vector(
|
|
|
107
116
|
|
|
108
117
|
change_times = many_scalars(x.change_times, data)
|
|
109
118
|
if isinstance(x.value, str):
|
|
110
|
-
e =
|
|
119
|
+
e = _eval_expression(x.value, data)
|
|
111
120
|
if tg.is_many_one_or_many_scalars(e):
|
|
112
121
|
value = e
|
|
113
122
|
else:
|
|
@@ -135,7 +144,7 @@ def one_or_many_2D_scalars(
|
|
|
135
144
|
x: cfg.OneOrMany2DScalars, data: dict[str, Any]
|
|
136
145
|
) -> pgt.OneOrMany2DScalars:
|
|
137
146
|
if isinstance(x, str):
|
|
138
|
-
e =
|
|
147
|
+
e = _eval_expression(x, data)
|
|
139
148
|
if tg.is_one_or_many_2D_scalars(e):
|
|
140
149
|
return e
|
|
141
150
|
raise ValueError(
|
|
@@ -153,7 +162,7 @@ def skyline_matrix(
|
|
|
153
162
|
return None
|
|
154
163
|
|
|
155
164
|
if isinstance(x, str):
|
|
156
|
-
e =
|
|
165
|
+
e = _eval_expression(x, data)
|
|
157
166
|
if tg.is_one_or_many_2D_scalars(e):
|
|
158
167
|
return e
|
|
159
168
|
raise ValueError(
|
|
@@ -168,7 +177,7 @@ def skyline_matrix(
|
|
|
168
177
|
|
|
169
178
|
change_times = many_scalars(x.change_times, data)
|
|
170
179
|
if isinstance(x.value, str):
|
|
171
|
-
e =
|
|
180
|
+
e = _eval_expression(x.value, data)
|
|
172
181
|
if tg.is_many_one_or_many_2D_scalars(e):
|
|
173
182
|
value = e
|
|
174
183
|
else:
|
phylogenie/generators/trees.py
CHANGED
|
@@ -9,6 +9,7 @@ from pydantic import Field
|
|
|
9
9
|
import phylogenie.generators.configs as cfg
|
|
10
10
|
from phylogenie.generators.dataset import DatasetGenerator, DataType
|
|
11
11
|
from phylogenie.generators.factories import (
|
|
12
|
+
distribution,
|
|
12
13
|
integer,
|
|
13
14
|
scalar,
|
|
14
15
|
skyline_matrix,
|
|
@@ -19,6 +20,7 @@ from phylogenie.io import dump_newick
|
|
|
19
20
|
from phylogenie.tree import Tree
|
|
20
21
|
from phylogenie.treesimulator import (
|
|
21
22
|
Event,
|
|
23
|
+
Mutation,
|
|
22
24
|
get_BD_events,
|
|
23
25
|
get_BDEI_events,
|
|
24
26
|
get_BDSS_events,
|
|
@@ -46,7 +48,6 @@ class TreeDatasetGenerator(DatasetGenerator):
|
|
|
46
48
|
max_time: cfg.Scalar = np.inf
|
|
47
49
|
init_state: str | None = None
|
|
48
50
|
sampling_probability_at_present: cfg.Scalar = 0.0
|
|
49
|
-
max_tries: int | None = None
|
|
50
51
|
|
|
51
52
|
@abstractmethod
|
|
52
53
|
def _get_events(self, data: dict[str, Any]) -> list[Event]: ...
|
|
@@ -67,7 +68,6 @@ class TreeDatasetGenerator(DatasetGenerator):
|
|
|
67
68
|
sampling_probability_at_present=scalar(
|
|
68
69
|
self.sampling_probability_at_present, data
|
|
69
70
|
),
|
|
70
|
-
max_tries=self.max_tries,
|
|
71
71
|
seed=int(rng.integers(2**32)),
|
|
72
72
|
)
|
|
73
73
|
|
|
@@ -127,11 +127,12 @@ class FBDTreeDatasetGenerator(TreeDatasetGenerator):
|
|
|
127
127
|
)
|
|
128
128
|
|
|
129
129
|
|
|
130
|
-
class
|
|
130
|
+
class TreeDatasetGeneratorForEpidemiology(TreeDatasetGenerator):
|
|
131
131
|
max_notified_contacts: cfg.Integer = 1
|
|
132
132
|
notification_probability: cfg.SkylineParameter = 0.0
|
|
133
133
|
sampling_rate_after_notification: cfg.SkylineParameter = np.inf
|
|
134
134
|
samplable_states_after_notification: list[str] | None = None
|
|
135
|
+
mutations: tuple[cfg.Mutation, ...] = Field(default_factory=tuple)
|
|
135
136
|
|
|
136
137
|
@abstractmethod
|
|
137
138
|
def _get_base_events(self, data: dict[str, Any]) -> list[Event]: ...
|
|
@@ -150,10 +151,30 @@ class TreeDatasetGeneratorWithContactTracing(TreeDatasetGenerator):
|
|
|
150
151
|
),
|
|
151
152
|
samplable_states_after_notification=self.samplable_states_after_notification,
|
|
152
153
|
)
|
|
154
|
+
all_states = list({e.state for e in events})
|
|
155
|
+
for mutation in self.mutations:
|
|
156
|
+
states = mutation.states
|
|
157
|
+
if isinstance(states, str):
|
|
158
|
+
states = [states]
|
|
159
|
+
elif states is None:
|
|
160
|
+
states = all_states
|
|
161
|
+
for state in states:
|
|
162
|
+
if state not in all_states:
|
|
163
|
+
raise ValueError(
|
|
164
|
+
f"Mutation state '{state}' is not found in states {all_states}."
|
|
165
|
+
)
|
|
166
|
+
rate_scalers = {
|
|
167
|
+
t: distribution(r, data) for t, r in mutation.rate_scalers.items()
|
|
168
|
+
}
|
|
169
|
+
events.append(
|
|
170
|
+
Mutation(
|
|
171
|
+
state, skyline_parameter(mutation.rate, data), rate_scalers
|
|
172
|
+
)
|
|
173
|
+
)
|
|
153
174
|
return events
|
|
154
175
|
|
|
155
176
|
|
|
156
|
-
class EpidemiologicalTreeDatasetGenerator(
|
|
177
|
+
class EpidemiologicalTreeDatasetGenerator(TreeDatasetGeneratorForEpidemiology):
|
|
157
178
|
parameterization: Literal[ParameterizationType.EPIDEMIOLOGICAL] = (
|
|
158
179
|
ParameterizationType.EPIDEMIOLOGICAL
|
|
159
180
|
)
|
|
@@ -179,7 +200,7 @@ class EpidemiologicalTreeDatasetGenerator(TreeDatasetGeneratorWithContactTracing
|
|
|
179
200
|
)
|
|
180
201
|
|
|
181
202
|
|
|
182
|
-
class BDTreeDatasetGenerator(
|
|
203
|
+
class BDTreeDatasetGenerator(TreeDatasetGeneratorForEpidemiology):
|
|
183
204
|
parameterization: Literal[ParameterizationType.BD] = ParameterizationType.BD
|
|
184
205
|
reproduction_number: cfg.SkylineParameter
|
|
185
206
|
infectious_period: cfg.SkylineParameter
|
|
@@ -193,7 +214,7 @@ class BDTreeDatasetGenerator(TreeDatasetGeneratorWithContactTracing):
|
|
|
193
214
|
)
|
|
194
215
|
|
|
195
216
|
|
|
196
|
-
class BDEITreeDatasetGenerator(
|
|
217
|
+
class BDEITreeDatasetGenerator(TreeDatasetGeneratorForEpidemiology):
|
|
197
218
|
parameterization: Literal[ParameterizationType.BDEI] = ParameterizationType.BDEI
|
|
198
219
|
reproduction_number: cfg.SkylineParameter
|
|
199
220
|
infectious_period: cfg.SkylineParameter
|
|
@@ -209,7 +230,7 @@ class BDEITreeDatasetGenerator(TreeDatasetGeneratorWithContactTracing):
|
|
|
209
230
|
)
|
|
210
231
|
|
|
211
232
|
|
|
212
|
-
class BDSSTreeDatasetGenerator(
|
|
233
|
+
class BDSSTreeDatasetGenerator(TreeDatasetGeneratorForEpidemiology):
|
|
213
234
|
parameterization: Literal[ParameterizationType.BDSS] = ParameterizationType.BDSS
|
|
214
235
|
reproduction_number: cfg.SkylineParameter
|
|
215
236
|
infectious_period: cfg.SkylineParameter
|
phylogenie/io.py
CHANGED
|
@@ -24,21 +24,28 @@ def _parse_newick(newick: str) -> Tree:
|
|
|
24
24
|
stack.append(current_nodes)
|
|
25
25
|
current_nodes = []
|
|
26
26
|
else:
|
|
27
|
-
id = _parse_chars([":", ",", ")", ";"])
|
|
27
|
+
id = _parse_chars([":", ",", ")", ";", "["])
|
|
28
28
|
branch_length = None
|
|
29
29
|
if newick[i] == ":":
|
|
30
30
|
i += 1
|
|
31
|
-
branch_length = _parse_chars([",", ")", ";"])
|
|
31
|
+
branch_length = float(_parse_chars([",", ")", ";", "["]))
|
|
32
32
|
|
|
33
|
-
current_node = Tree(
|
|
34
|
-
id=id,
|
|
35
|
-
branch_length=(None if branch_length is None else float(branch_length)),
|
|
36
|
-
)
|
|
33
|
+
current_node = Tree(id, branch_length)
|
|
37
34
|
for node in current_children:
|
|
38
35
|
current_node.add_child(node)
|
|
39
36
|
current_children = []
|
|
40
37
|
current_nodes.append(current_node)
|
|
41
38
|
|
|
39
|
+
if newick[i] == "[":
|
|
40
|
+
i += 1
|
|
41
|
+
features = _parse_chars(["]"]).split(":")
|
|
42
|
+
i += 1
|
|
43
|
+
if features[0] != "&&NHX":
|
|
44
|
+
raise ValueError(f"Expected '&&NHX' for node features.")
|
|
45
|
+
for feature in features[1:]:
|
|
46
|
+
key, value = feature.split("=", 1)
|
|
47
|
+
current_node.set(key, eval(value))
|
|
48
|
+
|
|
42
49
|
if newick[i] == ")":
|
|
43
50
|
current_children = current_nodes
|
|
44
51
|
current_nodes = stack.pop()
|
|
@@ -47,7 +54,7 @@ def _parse_newick(newick: str) -> Tree:
|
|
|
47
54
|
|
|
48
55
|
i += 1
|
|
49
56
|
|
|
50
|
-
raise ValueError("Newick string
|
|
57
|
+
raise ValueError("Newick string is invalid.")
|
|
51
58
|
|
|
52
59
|
|
|
53
60
|
def load_newick(filepath: str) -> Tree | list[Tree]:
|
|
@@ -63,6 +70,10 @@ def _to_newick(tree: Tree) -> str:
|
|
|
63
70
|
newick = f"({children_newick}){newick}"
|
|
64
71
|
if tree.branch_length is not None:
|
|
65
72
|
newick += f":{tree.branch_length}"
|
|
73
|
+
if tree.features:
|
|
74
|
+
reprs = {k: repr(v).replace("'", '"') for k, v in tree.features.items()}
|
|
75
|
+
features = [f"{k}={repr}" for k, repr in reprs.items()]
|
|
76
|
+
newick += f"[&&NHX:{':'.join(features)}]"
|
|
66
77
|
return newick
|
|
67
78
|
|
|
68
79
|
|
|
@@ -83,13 +94,14 @@ def load_fasta(
|
|
|
83
94
|
if not line.startswith(">"):
|
|
84
95
|
raise ValueError(f"Invalid FASTA format: expected '>', got '{line[0]}'")
|
|
85
96
|
id = line[1:].strip()
|
|
97
|
+
time = None
|
|
86
98
|
if extract_time_from_id is not None:
|
|
87
99
|
time = extract_time_from_id(id)
|
|
88
|
-
|
|
100
|
+
elif "|" in id:
|
|
89
101
|
try:
|
|
90
102
|
time = float(id.split("|")[-1])
|
|
91
|
-
except:
|
|
92
|
-
|
|
103
|
+
except ValueError:
|
|
104
|
+
pass
|
|
93
105
|
chars = next(f).strip()
|
|
94
106
|
sequences.append(Sequence(id, chars, time))
|
|
95
107
|
return MSA(sequences)
|
phylogenie/tree.py
CHANGED
|
@@ -1,18 +1,41 @@
|
|
|
1
1
|
from collections.abc import Iterator
|
|
2
|
+
from typing import Any
|
|
2
3
|
|
|
3
4
|
|
|
4
5
|
class Tree:
|
|
5
6
|
def __init__(self, id: str = "", branch_length: float | None = None):
|
|
6
7
|
self.id = id
|
|
7
8
|
self.branch_length = branch_length
|
|
8
|
-
self.
|
|
9
|
-
self.
|
|
9
|
+
self._parent: Tree | None = None
|
|
10
|
+
self._children: list[Tree] = []
|
|
11
|
+
self._features: dict[str, Any] = {}
|
|
12
|
+
|
|
13
|
+
@property
|
|
14
|
+
def children(self) -> tuple["Tree", ...]:
|
|
15
|
+
return tuple(self._children)
|
|
16
|
+
|
|
17
|
+
@property
|
|
18
|
+
def parent(self) -> "Tree | None":
|
|
19
|
+
return self._parent
|
|
20
|
+
|
|
21
|
+
@property
|
|
22
|
+
def features(self) -> dict[str, Any]:
|
|
23
|
+
return self._features
|
|
10
24
|
|
|
11
25
|
def add_child(self, child: "Tree") -> "Tree":
|
|
12
|
-
child.
|
|
13
|
-
self.
|
|
26
|
+
child._parent = self
|
|
27
|
+
self._children.append(child)
|
|
14
28
|
return self
|
|
15
29
|
|
|
30
|
+
def remove_child(self, child: "Tree") -> None:
|
|
31
|
+
self._children.remove(child)
|
|
32
|
+
child._parent = None
|
|
33
|
+
|
|
34
|
+
def set_parent(self, node: "Tree | None"):
|
|
35
|
+
self._parent = node
|
|
36
|
+
if node is not None:
|
|
37
|
+
node._children.append(self)
|
|
38
|
+
|
|
16
39
|
def preorder_traversal(self) -> Iterator["Tree"]:
|
|
17
40
|
yield self
|
|
18
41
|
for child in self.children:
|
|
@@ -29,6 +52,9 @@ class Tree:
|
|
|
29
52
|
return node
|
|
30
53
|
raise ValueError(f"Node with id {id} not found.")
|
|
31
54
|
|
|
55
|
+
def is_leaf(self) -> bool:
|
|
56
|
+
return not self.children
|
|
57
|
+
|
|
32
58
|
def get_leaves(self) -> list["Tree"]:
|
|
33
59
|
return [node for node in self if not node.children]
|
|
34
60
|
|
|
@@ -38,11 +64,19 @@ class Tree:
|
|
|
38
64
|
raise ValueError(f"Branch length of node {self.id} is not set.")
|
|
39
65
|
return self.branch_length + parent_time
|
|
40
66
|
|
|
41
|
-
def
|
|
42
|
-
|
|
67
|
+
def set(self, key: str, value: Any) -> None:
|
|
68
|
+
self._features[key] = value
|
|
69
|
+
|
|
70
|
+
def get(self, key: str) -> Any:
|
|
71
|
+
return self._features.get(key)
|
|
72
|
+
|
|
73
|
+
def delete(self, key: str) -> None:
|
|
74
|
+
del self._features[key]
|
|
43
75
|
|
|
44
|
-
def copy(self)
|
|
76
|
+
def copy(self):
|
|
45
77
|
new_tree = Tree(self.id, self.branch_length)
|
|
78
|
+
for key, value in self._features.items():
|
|
79
|
+
new_tree.set(key, value)
|
|
46
80
|
for child in self.children:
|
|
47
81
|
new_tree.add_child(child.copy())
|
|
48
82
|
return new_tree
|
|
@@ -51,4 +85,4 @@ class Tree:
|
|
|
51
85
|
return self.preorder_traversal()
|
|
52
86
|
|
|
53
87
|
def __repr__(self) -> str:
|
|
54
|
-
return f"TreeNode(id='{self.id}', branch_length={self.branch_length})"
|
|
88
|
+
return f"TreeNode(id='{self.id}', branch_length={self.branch_length}, features={self.features})"
|
|
@@ -1,5 +1,13 @@
|
|
|
1
1
|
from phylogenie.treesimulator.events import (
|
|
2
|
+
Birth,
|
|
3
|
+
BirthWithContactTracing,
|
|
4
|
+
Death,
|
|
2
5
|
Event,
|
|
6
|
+
Migration,
|
|
7
|
+
Mutation,
|
|
8
|
+
MutationTargetType,
|
|
9
|
+
Sampling,
|
|
10
|
+
SamplingWithContactTracing,
|
|
3
11
|
get_BD_events,
|
|
4
12
|
get_BDEI_events,
|
|
5
13
|
get_BDSS_events,
|
|
@@ -11,7 +19,15 @@ from phylogenie.treesimulator.events import (
|
|
|
11
19
|
from phylogenie.treesimulator.gillespie import generate_trees, simulate_tree
|
|
12
20
|
|
|
13
21
|
__all__ = [
|
|
22
|
+
"Birth",
|
|
23
|
+
"BirthWithContactTracing",
|
|
24
|
+
"Death",
|
|
14
25
|
"Event",
|
|
26
|
+
"Migration",
|
|
27
|
+
"Mutation",
|
|
28
|
+
"MutationTargetType",
|
|
29
|
+
"Sampling",
|
|
30
|
+
"SamplingWithContactTracing",
|
|
15
31
|
"get_BD_events",
|
|
16
32
|
"get_BDEI_events",
|
|
17
33
|
"get_BDSS_events",
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from phylogenie.treesimulator.events.contact_tracing import (
|
|
2
|
+
BirthWithContactTracing,
|
|
3
|
+
SamplingWithContactTracing,
|
|
4
|
+
get_contact_tracing_events,
|
|
5
|
+
)
|
|
6
|
+
from phylogenie.treesimulator.events.core import (
|
|
7
|
+
Birth,
|
|
8
|
+
Death,
|
|
9
|
+
Event,
|
|
10
|
+
Migration,
|
|
11
|
+
Sampling,
|
|
12
|
+
get_BD_events,
|
|
13
|
+
get_BDEI_events,
|
|
14
|
+
get_BDSS_events,
|
|
15
|
+
get_canonical_events,
|
|
16
|
+
get_epidemiological_events,
|
|
17
|
+
get_FBD_events,
|
|
18
|
+
)
|
|
19
|
+
from phylogenie.treesimulator.events.mutations import Mutation
|
|
20
|
+
from phylogenie.treesimulator.events.mutations import TargetType as MutationTargetType
|
|
21
|
+
|
|
22
|
+
__all__ = [
|
|
23
|
+
"Birth",
|
|
24
|
+
"BirthWithContactTracing",
|
|
25
|
+
"Death",
|
|
26
|
+
"Event",
|
|
27
|
+
"Migration",
|
|
28
|
+
"Mutation",
|
|
29
|
+
"Sampling",
|
|
30
|
+
"SamplingWithContactTracing",
|
|
31
|
+
"MutationTargetType",
|
|
32
|
+
"get_BD_events",
|
|
33
|
+
"get_BDEI_events",
|
|
34
|
+
"get_BDSS_events",
|
|
35
|
+
"get_canonical_events",
|
|
36
|
+
"get_contact_tracing_events",
|
|
37
|
+
"get_epidemiological_events",
|
|
38
|
+
"get_FBD_events",
|
|
39
|
+
]
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from collections.abc import Sequence
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
from numpy.random import Generator
|
|
6
|
+
|
|
7
|
+
from phylogenie.skyline import SkylineParameterLike, skyline_parameter
|
|
8
|
+
from phylogenie.treesimulator.events.core import Birth, Death, Migration, Sampling
|
|
9
|
+
from phylogenie.treesimulator.model import Event, Model
|
|
10
|
+
|
|
11
|
+
CT_POSTFIX = "-CT"
|
|
12
|
+
CONTACTS_KEY = "CONTACTS"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _get_CT_state(state: str) -> str:
|
|
16
|
+
return f"{state}{CT_POSTFIX}"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _is_CT_state(state: str) -> bool:
|
|
20
|
+
return state.endswith(CT_POSTFIX)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class BirthWithContactTracing(Event):
|
|
24
|
+
def __init__(self, state: str, rate: SkylineParameterLike, child_state: str):
|
|
25
|
+
super().__init__(state, rate)
|
|
26
|
+
self.child_state = child_state
|
|
27
|
+
|
|
28
|
+
def apply(self, model: Model, time: float, rng: Generator) -> None:
|
|
29
|
+
individual = self.draw_individual(model, rng)
|
|
30
|
+
new_individual = model.birth_from(individual, self.child_state, time)
|
|
31
|
+
if CONTACTS_KEY not in model.context:
|
|
32
|
+
model.context[CONTACTS_KEY] = defaultdict(list)
|
|
33
|
+
model.context[CONTACTS_KEY][individual].append(new_individual)
|
|
34
|
+
model.context[CONTACTS_KEY][new_individual].append(individual)
|
|
35
|
+
|
|
36
|
+
def __repr__(self) -> str:
|
|
37
|
+
return f"BirthWithContactTracing(state={self.state}, rate={self.rate}, child_state={self.child_state})"
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class SamplingWithContactTracing(Event):
|
|
41
|
+
def __init__(
|
|
42
|
+
self,
|
|
43
|
+
state: str,
|
|
44
|
+
rate: SkylineParameterLike,
|
|
45
|
+
max_notified_contacts: int,
|
|
46
|
+
notification_probability: SkylineParameterLike,
|
|
47
|
+
):
|
|
48
|
+
super().__init__(state, rate)
|
|
49
|
+
self.max_notified_contacts = max_notified_contacts
|
|
50
|
+
self.notification_probability = skyline_parameter(notification_probability)
|
|
51
|
+
|
|
52
|
+
def apply(self, model: Model, time: float, rng: Generator) -> None:
|
|
53
|
+
individual = self.draw_individual(model, rng)
|
|
54
|
+
model.sample(individual, time, True)
|
|
55
|
+
population = model.get_population()
|
|
56
|
+
if CONTACTS_KEY not in model.context:
|
|
57
|
+
return
|
|
58
|
+
contacts = model.context[CONTACTS_KEY][individual]
|
|
59
|
+
for contact in contacts[-self.max_notified_contacts :]:
|
|
60
|
+
if contact in population:
|
|
61
|
+
state = model.get_state(contact)
|
|
62
|
+
p = self.notification_probability.get_value_at_time(time)
|
|
63
|
+
if not _is_CT_state(state) and rng.random() < p:
|
|
64
|
+
model.migrate(contact, _get_CT_state(state), time)
|
|
65
|
+
|
|
66
|
+
def __repr__(self) -> str:
|
|
67
|
+
return f"SamplingWithContactTracing(state={self.state}, rate={self.rate}, max_notified_contacts={self.max_notified_contacts}, notification_probability={self.notification_probability})"
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def get_contact_tracing_events(
|
|
71
|
+
events: Sequence[Event],
|
|
72
|
+
max_notified_contacts: int = 1,
|
|
73
|
+
notification_probability: SkylineParameterLike = 1,
|
|
74
|
+
sampling_rate_after_notification: SkylineParameterLike = np.inf,
|
|
75
|
+
samplable_states_after_notification: list[str] | None = None,
|
|
76
|
+
) -> list[Event]:
|
|
77
|
+
ct_events: list[Event] = []
|
|
78
|
+
notification_probability = skyline_parameter(notification_probability)
|
|
79
|
+
sampling_rate_after_notification = skyline_parameter(
|
|
80
|
+
sampling_rate_after_notification
|
|
81
|
+
)
|
|
82
|
+
for event in events:
|
|
83
|
+
state, rate = event.state, event.rate
|
|
84
|
+
if isinstance(event, Migration):
|
|
85
|
+
ct_events.append(event)
|
|
86
|
+
ct_events.append(
|
|
87
|
+
Migration(_get_CT_state(state), rate, _get_CT_state(event.target_state))
|
|
88
|
+
)
|
|
89
|
+
elif isinstance(event, Birth):
|
|
90
|
+
ct_events.append(BirthWithContactTracing(state, rate, event.child_state))
|
|
91
|
+
ct_events.append(
|
|
92
|
+
BirthWithContactTracing(_get_CT_state(state), rate, event.child_state)
|
|
93
|
+
)
|
|
94
|
+
elif isinstance(event, Sampling):
|
|
95
|
+
if not event.removal:
|
|
96
|
+
raise ValueError(
|
|
97
|
+
"Contact tracing requires removal to be set for all sampling events."
|
|
98
|
+
)
|
|
99
|
+
ct_events.append(
|
|
100
|
+
SamplingWithContactTracing(
|
|
101
|
+
state, rate, max_notified_contacts, notification_probability
|
|
102
|
+
)
|
|
103
|
+
)
|
|
104
|
+
elif isinstance(event, Death):
|
|
105
|
+
ct_events.append(event)
|
|
106
|
+
else:
|
|
107
|
+
raise NotImplementedError(
|
|
108
|
+
f"Unsupported event type {type(event)} for contact tracing."
|
|
109
|
+
)
|
|
110
|
+
|
|
111
|
+
for state in (
|
|
112
|
+
samplable_states_after_notification
|
|
113
|
+
if samplable_states_after_notification is not None
|
|
114
|
+
else {e.state for e in events}
|
|
115
|
+
):
|
|
116
|
+
ct_events.append(
|
|
117
|
+
SamplingWithContactTracing(
|
|
118
|
+
_get_CT_state(state),
|
|
119
|
+
sampling_rate_after_notification,
|
|
120
|
+
max_notified_contacts,
|
|
121
|
+
notification_probability,
|
|
122
|
+
)
|
|
123
|
+
)
|
|
124
|
+
|
|
125
|
+
return ct_events
|