phylogenie 2.1.24__tar.gz → 2.1.26__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {phylogenie-2.1.24 → phylogenie-2.1.26}/PKG-INFO +1 -1
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/__init__.py +2 -2
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/generators/alisim.py +2 -1
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/generators/configs.py +3 -3
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/generators/factories.py +15 -0
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/generators/trees.py +14 -32
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/skyline/matrix.py +11 -7
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/skyline/parameter.py +12 -4
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/skyline/vector.py +12 -6
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/treesimulator/__init__.py +2 -2
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/treesimulator/events/__init__.py +4 -6
- phylogenie-2.1.26/phylogenie/treesimulator/events/base.py +44 -0
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/treesimulator/events/contact_tracing.py +30 -14
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/treesimulator/events/core.py +14 -5
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/treesimulator/events/mutations.py +30 -45
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/treesimulator/features.py +1 -1
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/treesimulator/gillespie.py +28 -18
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/treesimulator/model.py +1 -34
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/typings.py +3 -3
- {phylogenie-2.1.24 → phylogenie-2.1.26}/pyproject.toml +1 -1
- {phylogenie-2.1.24 → phylogenie-2.1.26}/LICENSE.txt +0 -0
- {phylogenie-2.1.24 → phylogenie-2.1.26}/README.md +0 -0
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/draw.py +0 -0
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/generators/__init__.py +0 -0
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/generators/dataset.py +0 -0
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/generators/typeguards.py +0 -0
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/io/__init__.py +0 -0
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/io/fasta.py +0 -0
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/io/newick.py +0 -0
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/io/nexus.py +0 -0
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/main.py +0 -0
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/models.py +0 -0
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/msa.py +0 -0
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/py.typed +0 -0
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/skyline/__init__.py +0 -0
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/tree.py +0 -0
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/typeguards.py +0 -0
- {phylogenie-2.1.24 → phylogenie-2.1.26}/phylogenie/utils.py +0 -0
|
@@ -31,9 +31,9 @@ from phylogenie.treesimulator import (
|
|
|
31
31
|
BirthWithContactTracing,
|
|
32
32
|
Death,
|
|
33
33
|
Event,
|
|
34
|
+
EventType,
|
|
34
35
|
Migration,
|
|
35
36
|
Mutation,
|
|
36
|
-
MutationTargetType,
|
|
37
37
|
Sampling,
|
|
38
38
|
SamplingWithContactTracing,
|
|
39
39
|
generate_trees,
|
|
@@ -85,9 +85,9 @@ __all__ = [
|
|
|
85
85
|
"BirthWithContactTracing",
|
|
86
86
|
"Death",
|
|
87
87
|
"Event",
|
|
88
|
+
"EventType",
|
|
88
89
|
"Migration",
|
|
89
90
|
"Mutation",
|
|
90
|
-
"MutationTargetType",
|
|
91
91
|
"Sampling",
|
|
92
92
|
"SamplingWithContactTracing",
|
|
93
93
|
"get_BD_events",
|
|
@@ -65,12 +65,13 @@ class AliSimDatasetGenerator(DatasetGenerator):
|
|
|
65
65
|
while True:
|
|
66
66
|
d.update(data(context, rng))
|
|
67
67
|
try:
|
|
68
|
-
tree = self.trees.simulate_one(d, seed)
|
|
68
|
+
tree, metadata = self.trees.simulate_one(d, seed)
|
|
69
69
|
break
|
|
70
70
|
except TimeoutError:
|
|
71
71
|
print(
|
|
72
72
|
"Tree simulation timed out, retrying with different parameters..."
|
|
73
73
|
)
|
|
74
|
+
d.update(metadata)
|
|
74
75
|
|
|
75
76
|
times = get_node_depths(tree)
|
|
76
77
|
for leaf in tree.get_leaves():
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import phylogenie.typings as pgt
|
|
2
2
|
from phylogenie.models import Distribution, StrictBaseModel
|
|
3
|
-
from phylogenie.treesimulator import
|
|
3
|
+
from phylogenie.treesimulator import EventType
|
|
4
4
|
|
|
5
5
|
Integer = str | int
|
|
6
6
|
Scalar = str | pgt.Scalar
|
|
@@ -30,9 +30,9 @@ SkylineMatrix = str | pgt.Scalar | pgt.Many[SkylineVector] | SkylineMatrixModel
|
|
|
30
30
|
|
|
31
31
|
|
|
32
32
|
class Event(StrictBaseModel):
|
|
33
|
-
|
|
33
|
+
state: str | None = None
|
|
34
34
|
rate: SkylineParameter
|
|
35
35
|
|
|
36
36
|
|
|
37
37
|
class Mutation(Event):
|
|
38
|
-
rate_scalers: dict[
|
|
38
|
+
rate_scalers: dict[EventType, Distribution]
|
|
@@ -17,6 +17,7 @@ from phylogenie.skyline import (
|
|
|
17
17
|
SkylineVector,
|
|
18
18
|
SkylineVectorCoercible,
|
|
19
19
|
)
|
|
20
|
+
from phylogenie.treesimulator import Mutation
|
|
20
21
|
|
|
21
22
|
|
|
22
23
|
def _eval_expression(expression: str, data: dict[str, Any]) -> Any:
|
|
@@ -221,6 +222,20 @@ def distribution(x: Distribution, data: dict[str, Any]) -> Distribution:
|
|
|
221
222
|
return Distribution(type=x.type, **args)
|
|
222
223
|
|
|
223
224
|
|
|
225
|
+
def mutations(
|
|
226
|
+
x: list[cfg.Mutation], data: dict[str, Any], states: set[str]
|
|
227
|
+
) -> list[Mutation]:
|
|
228
|
+
mutations: list[Mutation] = []
|
|
229
|
+
for m in x:
|
|
230
|
+
rate = skyline_parameter(m.rate, data)
|
|
231
|
+
rate_scalers = {k: distribution(v, data) for k, v in m.rate_scalers.items()}
|
|
232
|
+
if m.state is None:
|
|
233
|
+
mutations.extend(Mutation(s, rate, rate_scalers) for s in states)
|
|
234
|
+
else:
|
|
235
|
+
mutations.append(Mutation(m.state, rate, rate_scalers))
|
|
236
|
+
return mutations
|
|
237
|
+
|
|
238
|
+
|
|
224
239
|
def data(context: dict[str, Distribution] | None, rng: Generator) -> dict[str, Any]:
|
|
225
240
|
if context is None:
|
|
226
241
|
return {}
|
|
@@ -11,8 +11,8 @@ import phylogenie.generators.configs as cfg
|
|
|
11
11
|
from phylogenie.generators.dataset import DatasetGenerator, DataType
|
|
12
12
|
from phylogenie.generators.factories import (
|
|
13
13
|
data,
|
|
14
|
-
distribution,
|
|
15
14
|
integer,
|
|
15
|
+
mutations,
|
|
16
16
|
scalar,
|
|
17
17
|
skyline_matrix,
|
|
18
18
|
skyline_parameter,
|
|
@@ -24,7 +24,6 @@ from phylogenie.tree import Tree
|
|
|
24
24
|
from phylogenie.treesimulator import (
|
|
25
25
|
Event,
|
|
26
26
|
Feature,
|
|
27
|
-
Mutation,
|
|
28
27
|
get_BD_events,
|
|
29
28
|
get_BDEI_events,
|
|
30
29
|
get_BDSS_events,
|
|
@@ -48,6 +47,7 @@ class ParameterizationType(str, Enum):
|
|
|
48
47
|
|
|
49
48
|
class TreeDatasetGenerator(DatasetGenerator):
|
|
50
49
|
data_type: Literal[DataType.TREES] = DataType.TREES
|
|
50
|
+
mutations: list[cfg.Mutation] = Field(default_factory=lambda: [])
|
|
51
51
|
min_tips: cfg.Integer = 1
|
|
52
52
|
max_tips: cfg.Integer | None = None
|
|
53
53
|
max_time: cfg.Scalar = np.inf
|
|
@@ -59,14 +59,17 @@ class TreeDatasetGenerator(DatasetGenerator):
|
|
|
59
59
|
@abstractmethod
|
|
60
60
|
def _get_events(self, data: dict[str, Any]) -> list[Event]: ...
|
|
61
61
|
|
|
62
|
-
def simulate_one(
|
|
62
|
+
def simulate_one(
|
|
63
|
+
self, data: dict[str, Any], seed: int | None = None
|
|
64
|
+
) -> tuple[Tree, dict[str, Any]]:
|
|
63
65
|
init_state = (
|
|
64
66
|
self.init_state
|
|
65
67
|
if self.init_state is None
|
|
66
68
|
else self.init_state.format(**data)
|
|
67
69
|
)
|
|
70
|
+
states = {e.state for e in self._get_events(data)}
|
|
68
71
|
return simulate_tree(
|
|
69
|
-
events=self._get_events(data),
|
|
72
|
+
events=self._get_events(data) + mutations(self.mutations, data, states),
|
|
70
73
|
min_tips=integer(self.min_tips, data),
|
|
71
74
|
max_tips=None if self.max_tips is None else integer(self.max_tips, data),
|
|
72
75
|
max_time=scalar(self.max_time, data),
|
|
@@ -89,14 +92,14 @@ class TreeDatasetGenerator(DatasetGenerator):
|
|
|
89
92
|
while True:
|
|
90
93
|
try:
|
|
91
94
|
d.update(data(context, rng))
|
|
92
|
-
tree = self.simulate_one(d, seed)
|
|
95
|
+
tree, metadata = self.simulate_one(d, seed)
|
|
93
96
|
if self.node_features is not None:
|
|
94
97
|
set_features(tree, self.node_features)
|
|
95
98
|
dump_newick(tree, f"{filename}.nwk")
|
|
96
99
|
break
|
|
97
100
|
except TimeoutError:
|
|
98
101
|
print("Simulation timed out, retrying with different parameters...")
|
|
99
|
-
return d
|
|
102
|
+
return d | metadata
|
|
100
103
|
|
|
101
104
|
|
|
102
105
|
class CanonicalTreeDatasetGenerator(TreeDatasetGenerator):
|
|
@@ -147,12 +150,11 @@ class FBDTreeDatasetGenerator(TreeDatasetGenerator):
|
|
|
147
150
|
)
|
|
148
151
|
|
|
149
152
|
|
|
150
|
-
class
|
|
153
|
+
class ContactTracingTreeDatasetGenerator(TreeDatasetGenerator):
|
|
151
154
|
max_notified_contacts: cfg.Integer = 1
|
|
152
155
|
notification_probability: cfg.SkylineParameter = 0.0
|
|
153
156
|
sampling_rate_after_notification: cfg.SkylineParameter = np.inf
|
|
154
157
|
samplable_states_after_notification: list[str] | None = None
|
|
155
|
-
mutations: tuple[cfg.Mutation, ...] = Field(default_factory=tuple)
|
|
156
158
|
|
|
157
159
|
@abstractmethod
|
|
158
160
|
def _get_base_events(self, data: dict[str, Any]) -> list[Event]: ...
|
|
@@ -171,30 +173,10 @@ class TreeDatasetGeneratorForEpidemiology(TreeDatasetGenerator):
|
|
|
171
173
|
),
|
|
172
174
|
samplable_states_after_notification=self.samplable_states_after_notification,
|
|
173
175
|
)
|
|
174
|
-
all_states = list({e.state for e in events})
|
|
175
|
-
for mutation in self.mutations:
|
|
176
|
-
states = mutation.states
|
|
177
|
-
if isinstance(states, str):
|
|
178
|
-
states = [states]
|
|
179
|
-
elif states is None:
|
|
180
|
-
states = all_states
|
|
181
|
-
for state in states:
|
|
182
|
-
if state not in all_states:
|
|
183
|
-
raise ValueError(
|
|
184
|
-
f"Mutation state '{state}' is not found in states {all_states}."
|
|
185
|
-
)
|
|
186
|
-
rate_scalers = {
|
|
187
|
-
t: distribution(r, data) for t, r in mutation.rate_scalers.items()
|
|
188
|
-
}
|
|
189
|
-
events.append(
|
|
190
|
-
Mutation(
|
|
191
|
-
state, skyline_parameter(mutation.rate, data), rate_scalers
|
|
192
|
-
)
|
|
193
|
-
)
|
|
194
176
|
return events
|
|
195
177
|
|
|
196
178
|
|
|
197
|
-
class EpidemiologicalTreeDatasetGenerator(
|
|
179
|
+
class EpidemiologicalTreeDatasetGenerator(ContactTracingTreeDatasetGenerator):
|
|
198
180
|
parameterization: Literal[ParameterizationType.EPIDEMIOLOGICAL] = (
|
|
199
181
|
ParameterizationType.EPIDEMIOLOGICAL
|
|
200
182
|
)
|
|
@@ -220,7 +202,7 @@ class EpidemiologicalTreeDatasetGenerator(TreeDatasetGeneratorForEpidemiology):
|
|
|
220
202
|
)
|
|
221
203
|
|
|
222
204
|
|
|
223
|
-
class BDTreeDatasetGenerator(
|
|
205
|
+
class BDTreeDatasetGenerator(ContactTracingTreeDatasetGenerator):
|
|
224
206
|
parameterization: Literal[ParameterizationType.BD] = ParameterizationType.BD
|
|
225
207
|
reproduction_number: cfg.SkylineParameter
|
|
226
208
|
infectious_period: cfg.SkylineParameter
|
|
@@ -234,7 +216,7 @@ class BDTreeDatasetGenerator(TreeDatasetGeneratorForEpidemiology):
|
|
|
234
216
|
)
|
|
235
217
|
|
|
236
218
|
|
|
237
|
-
class BDEITreeDatasetGenerator(
|
|
219
|
+
class BDEITreeDatasetGenerator(ContactTracingTreeDatasetGenerator):
|
|
238
220
|
parameterization: Literal[ParameterizationType.BDEI] = ParameterizationType.BDEI
|
|
239
221
|
reproduction_number: cfg.SkylineParameter
|
|
240
222
|
infectious_period: cfg.SkylineParameter
|
|
@@ -250,7 +232,7 @@ class BDEITreeDatasetGenerator(TreeDatasetGeneratorForEpidemiology):
|
|
|
250
232
|
)
|
|
251
233
|
|
|
252
234
|
|
|
253
|
-
class BDSSTreeDatasetGenerator(
|
|
235
|
+
class BDSSTreeDatasetGenerator(ContactTracingTreeDatasetGenerator):
|
|
254
236
|
parameterization: Literal[ParameterizationType.BDSS] = ParameterizationType.BDSS
|
|
255
237
|
reproduction_number: cfg.SkylineParameter
|
|
256
238
|
infectious_period: cfg.SkylineParameter
|
|
@@ -33,7 +33,7 @@ class SkylineMatrix:
|
|
|
33
33
|
):
|
|
34
34
|
if params is not None and value is None and change_times is None:
|
|
35
35
|
if is_many_skyline_vectors_like(params):
|
|
36
|
-
self.
|
|
36
|
+
self._params = [
|
|
37
37
|
p if isinstance(p, SkylineVector) else SkylineVector(p)
|
|
38
38
|
for p in params
|
|
39
39
|
]
|
|
@@ -41,7 +41,7 @@ class SkylineMatrix:
|
|
|
41
41
|
raise TypeError(
|
|
42
42
|
f"It is impossible to create a SkylineMatrix from `params` {params} of type {type(params)}. Please provide a sequence composed of SkylineVectorLike objects (a SkylineVectorLike object can either be a SkylineVector or a sequence of scalars and/or SkylineParameters)."
|
|
43
43
|
)
|
|
44
|
-
lengths = {len(p) for p in self.
|
|
44
|
+
lengths = {len(p) for p in self._params}
|
|
45
45
|
if len(lengths) > 1:
|
|
46
46
|
raise ValueError(
|
|
47
47
|
f"All `params` must have the same length to create a SkylineMatrix (got params={params} with lengths {lengths})."
|
|
@@ -57,7 +57,7 @@ class SkylineMatrix:
|
|
|
57
57
|
raise TypeError(
|
|
58
58
|
f"It is impossible to create a SkylineMatrix from `value` {value} of type {type(value)}. Please provide a nested (3D) sequence of scalar values."
|
|
59
59
|
)
|
|
60
|
-
self.
|
|
60
|
+
self._params = [
|
|
61
61
|
SkylineVector(
|
|
62
62
|
value=[matrix[i] for matrix in value], change_times=change_times
|
|
63
63
|
)
|
|
@@ -68,6 +68,10 @@ class SkylineMatrix:
|
|
|
68
68
|
"Either `params` or both `value` and `change_times` must be provided to create a SkylineMatrix."
|
|
69
69
|
)
|
|
70
70
|
|
|
71
|
+
@property
|
|
72
|
+
def params(self) -> tuple[SkylineVector, ...]:
|
|
73
|
+
return tuple(self._params)
|
|
74
|
+
|
|
71
75
|
@property
|
|
72
76
|
def n_rows(self) -> int:
|
|
73
77
|
return len(self.params)
|
|
@@ -82,14 +86,14 @@ class SkylineMatrix:
|
|
|
82
86
|
|
|
83
87
|
@property
|
|
84
88
|
def change_times(self) -> pgt.Vector1D:
|
|
85
|
-
return sorted(set([t for row in self.params for t in row.change_times]))
|
|
89
|
+
return tuple(sorted(set([t for row in self.params for t in row.change_times])))
|
|
86
90
|
|
|
87
91
|
@property
|
|
88
92
|
def value(self) -> pgt.Vector3D:
|
|
89
|
-
return
|
|
93
|
+
return tuple(self.get_value_at_time(t) for t in (0, *self.change_times))
|
|
90
94
|
|
|
91
95
|
def get_value_at_time(self, time: pgt.Scalar) -> pgt.Vector2D:
|
|
92
|
-
return
|
|
96
|
+
return tuple(param.get_value_at_time(time) for param in self.params)
|
|
93
97
|
|
|
94
98
|
def _operate(
|
|
95
99
|
self,
|
|
@@ -185,7 +189,7 @@ class SkylineMatrix:
|
|
|
185
189
|
raise TypeError(
|
|
186
190
|
f"It is impossible to set item of SkylineMatrix to value {value} of type {type(value)}. Please provide a SkylineVectorLike object (i.e., a SkylineVector or a sequence of scalars and/or SkylineParameters)."
|
|
187
191
|
)
|
|
188
|
-
self.
|
|
192
|
+
self._params[item] = skyline_vector(value, self.n_cols)
|
|
189
193
|
|
|
190
194
|
|
|
191
195
|
def skyline_matrix(
|
|
@@ -52,12 +52,20 @@ class SkylineParameter:
|
|
|
52
52
|
f"`change_times` must be non-negative (got change_times={change_times})."
|
|
53
53
|
)
|
|
54
54
|
|
|
55
|
-
self.
|
|
56
|
-
self.
|
|
55
|
+
self._value = [value[0]]
|
|
56
|
+
self._change_times: list[pgt.Scalar] = []
|
|
57
57
|
for i in range(1, len(value)):
|
|
58
58
|
if value[i] != value[i - 1]:
|
|
59
|
-
self.
|
|
60
|
-
self.
|
|
59
|
+
self._value.append(value[i])
|
|
60
|
+
self._change_times.append(change_times[i - 1])
|
|
61
|
+
|
|
62
|
+
@property
|
|
63
|
+
def value(self) -> pgt.Vector1D:
|
|
64
|
+
return tuple(self._value)
|
|
65
|
+
|
|
66
|
+
@property
|
|
67
|
+
def change_times(self) -> pgt.Vector1D:
|
|
68
|
+
return tuple(self._change_times)
|
|
61
69
|
|
|
62
70
|
def get_value_at_time(self, t: pgt.Scalar) -> pgt.Scalar:
|
|
63
71
|
if t < 0:
|
|
@@ -47,7 +47,7 @@ class SkylineVector:
|
|
|
47
47
|
):
|
|
48
48
|
if params is not None and value is None and change_times is None:
|
|
49
49
|
if is_many_skyline_parameters_like(params):
|
|
50
|
-
self.
|
|
50
|
+
self._params = [skyline_parameter(param) for param in params]
|
|
51
51
|
else:
|
|
52
52
|
raise TypeError(
|
|
53
53
|
f"It is impossible to create a SkylineVector from `params` {params} of type {type(params)}. Please provide a sequence of SkylineParameterLike objects (a SkylineParameterLike object can either be a SkylineParameter or a scalar)."
|
|
@@ -63,7 +63,7 @@ class SkylineVector:
|
|
|
63
63
|
raise TypeError(
|
|
64
64
|
f"It is impossible to create a SkylineVector from `value` {value} of type {type(value)}. Please provide a nested (2D) sequence of scalar values."
|
|
65
65
|
)
|
|
66
|
-
self.
|
|
66
|
+
self._params = [
|
|
67
67
|
SkylineParameter([vector[i] for vector in value], change_times)
|
|
68
68
|
for i in range(len(value[0]))
|
|
69
69
|
]
|
|
@@ -72,20 +72,26 @@ class SkylineVector:
|
|
|
72
72
|
"Either `params` or both `value` and `change_times` must be provided to create a SkylineVector."
|
|
73
73
|
)
|
|
74
74
|
|
|
75
|
+
@property
|
|
76
|
+
def params(self) -> tuple[SkylineParameter, ...]:
|
|
77
|
+
return tuple(self._params)
|
|
78
|
+
|
|
75
79
|
@property
|
|
76
80
|
def change_times(self) -> pgt.Vector1D:
|
|
77
|
-
return
|
|
81
|
+
return tuple(
|
|
82
|
+
sorted(set(t for param in self.params for t in param.change_times))
|
|
83
|
+
)
|
|
78
84
|
|
|
79
85
|
@property
|
|
80
86
|
def value(self) -> pgt.Vector2D:
|
|
81
|
-
return
|
|
87
|
+
return tuple(self.get_value_at_time(t) for t in (0, *self.change_times))
|
|
82
88
|
|
|
83
89
|
@property
|
|
84
90
|
def N(self) -> int:
|
|
85
91
|
return len(self.params)
|
|
86
92
|
|
|
87
93
|
def get_value_at_time(self, t: pgt.Scalar) -> pgt.Vector1D:
|
|
88
|
-
return
|
|
94
|
+
return tuple(param.get_value_at_time(t) for param in self.params)
|
|
89
95
|
|
|
90
96
|
def _operate(
|
|
91
97
|
self,
|
|
@@ -154,7 +160,7 @@ class SkylineVector:
|
|
|
154
160
|
raise TypeError(
|
|
155
161
|
f"It is impossible to set item {item} of SkylineVector with value {value} of type {type(value)}. Please provide a SkylineParameterLike object (i.e., a scalar or a SkylineParameter)."
|
|
156
162
|
)
|
|
157
|
-
self.
|
|
163
|
+
self._params[item] = skyline_parameter(value)
|
|
158
164
|
|
|
159
165
|
|
|
160
166
|
def skyline_vector(x: SkylineVectorCoercible, N: int) -> SkylineVector:
|
|
@@ -3,9 +3,9 @@ from phylogenie.treesimulator.events import (
|
|
|
3
3
|
BirthWithContactTracing,
|
|
4
4
|
Death,
|
|
5
5
|
Event,
|
|
6
|
+
EventType,
|
|
6
7
|
Migration,
|
|
7
8
|
Mutation,
|
|
8
|
-
MutationTargetType,
|
|
9
9
|
Sampling,
|
|
10
10
|
SamplingWithContactTracing,
|
|
11
11
|
get_BD_events,
|
|
@@ -26,9 +26,9 @@ __all__ = [
|
|
|
26
26
|
"BirthWithContactTracing",
|
|
27
27
|
"Death",
|
|
28
28
|
"Event",
|
|
29
|
+
"EventType",
|
|
29
30
|
"Migration",
|
|
30
31
|
"Mutation",
|
|
31
|
-
"MutationTargetType",
|
|
32
32
|
"Sampling",
|
|
33
33
|
"SamplingWithContactTracing",
|
|
34
34
|
"get_BD_events",
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from phylogenie.treesimulator.events.base import Event, EventType
|
|
1
2
|
from phylogenie.treesimulator.events.contact_tracing import (
|
|
2
3
|
BirthWithContactTracing,
|
|
3
4
|
SamplingWithContactTracing,
|
|
@@ -6,7 +7,6 @@ from phylogenie.treesimulator.events.contact_tracing import (
|
|
|
6
7
|
from phylogenie.treesimulator.events.core import (
|
|
7
8
|
Birth,
|
|
8
9
|
Death,
|
|
9
|
-
Event,
|
|
10
10
|
Migration,
|
|
11
11
|
Sampling,
|
|
12
12
|
get_BD_events,
|
|
@@ -16,20 +16,17 @@ from phylogenie.treesimulator.events.core import (
|
|
|
16
16
|
get_epidemiological_events,
|
|
17
17
|
get_FBD_events,
|
|
18
18
|
)
|
|
19
|
-
from phylogenie.treesimulator.events.mutations import Mutation
|
|
20
|
-
from phylogenie.treesimulator.events.mutations import TargetType as MutationTargetType
|
|
21
|
-
from phylogenie.treesimulator.events.mutations import get_mutation_id
|
|
19
|
+
from phylogenie.treesimulator.events.mutations import Mutation, get_mutation_id
|
|
22
20
|
|
|
23
21
|
__all__ = [
|
|
24
22
|
"Birth",
|
|
25
23
|
"BirthWithContactTracing",
|
|
26
24
|
"Death",
|
|
27
25
|
"Event",
|
|
26
|
+
"EventType",
|
|
28
27
|
"Migration",
|
|
29
|
-
"Mutation",
|
|
30
28
|
"Sampling",
|
|
31
29
|
"SamplingWithContactTracing",
|
|
32
|
-
"MutationTargetType",
|
|
33
30
|
"get_BD_events",
|
|
34
31
|
"get_BDEI_events",
|
|
35
32
|
"get_BDSS_events",
|
|
@@ -37,5 +34,6 @@ __all__ = [
|
|
|
37
34
|
"get_contact_tracing_events",
|
|
38
35
|
"get_epidemiological_events",
|
|
39
36
|
"get_FBD_events",
|
|
37
|
+
"Mutation",
|
|
40
38
|
"get_mutation_id",
|
|
41
39
|
]
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from numpy.random import Generator
|
|
7
|
+
|
|
8
|
+
from phylogenie.skyline import SkylineParameterLike, skyline_parameter
|
|
9
|
+
from phylogenie.treesimulator.model import Model
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class EventType(str, Enum):
|
|
13
|
+
BIRTH = "birth"
|
|
14
|
+
DEATH = "death"
|
|
15
|
+
MIGRATION = "migration"
|
|
16
|
+
SAMPLING = "sampling"
|
|
17
|
+
MUTATION = "mutation"
|
|
18
|
+
BIRTH_WITH_CT = "birth_with_contact_tracing"
|
|
19
|
+
SAMPLING_WITH_CT = "sampling_with_contact_tracing"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class Event(ABC):
|
|
23
|
+
type: EventType
|
|
24
|
+
|
|
25
|
+
def __init__(self, state: str, rate: SkylineParameterLike):
|
|
26
|
+
self.state = state
|
|
27
|
+
self.rate = skyline_parameter(rate)
|
|
28
|
+
if any(v < 0 for v in self.rate.value):
|
|
29
|
+
raise ValueError("Event rates must be non-negative.")
|
|
30
|
+
|
|
31
|
+
def draw_individual(self, model: Model, rng: Generator) -> int:
|
|
32
|
+
return rng.choice(model.get_population(self.state))
|
|
33
|
+
|
|
34
|
+
def get_propensity(self, model: Model, time: float) -> float:
|
|
35
|
+
n_individuals = model.count_individuals(self.state)
|
|
36
|
+
rate = self.rate.get_value_at_time(time)
|
|
37
|
+
if rate == np.inf and not n_individuals:
|
|
38
|
+
return 0
|
|
39
|
+
return rate * n_individuals
|
|
40
|
+
|
|
41
|
+
@abstractmethod
|
|
42
|
+
def apply(
|
|
43
|
+
self, model: Model, events: "list[Event]", time: float, rng: Generator
|
|
44
|
+
) -> dict[str, Any] | None: ...
|
|
@@ -1,31 +1,35 @@
|
|
|
1
1
|
from collections import defaultdict
|
|
2
2
|
from collections.abc import Sequence
|
|
3
|
+
from copy import deepcopy
|
|
3
4
|
|
|
4
5
|
import numpy as np
|
|
5
6
|
from numpy.random import Generator
|
|
6
7
|
|
|
7
8
|
from phylogenie.skyline import SkylineParameterLike, skyline_parameter
|
|
9
|
+
from phylogenie.treesimulator.events.base import Event, EventType
|
|
8
10
|
from phylogenie.treesimulator.events.core import Birth, Death, Migration, Sampling
|
|
9
|
-
from phylogenie.treesimulator.model import
|
|
11
|
+
from phylogenie.treesimulator.model import Model
|
|
10
12
|
|
|
11
13
|
CT_POSTFIX = "-CT"
|
|
12
14
|
CONTACTS_KEY = "CONTACTS"
|
|
13
15
|
|
|
14
16
|
|
|
15
|
-
def
|
|
17
|
+
def get_CT_state(state: str) -> str:
|
|
16
18
|
return f"{state}{CT_POSTFIX}"
|
|
17
19
|
|
|
18
20
|
|
|
19
|
-
def
|
|
21
|
+
def is_CT_state(state: str) -> bool:
|
|
20
22
|
return state.endswith(CT_POSTFIX)
|
|
21
23
|
|
|
22
24
|
|
|
23
25
|
class BirthWithContactTracing(Event):
|
|
26
|
+
type = EventType.BIRTH_WITH_CT
|
|
27
|
+
|
|
24
28
|
def __init__(self, state: str, rate: SkylineParameterLike, child_state: str):
|
|
25
29
|
super().__init__(state, rate)
|
|
26
30
|
self.child_state = child_state
|
|
27
31
|
|
|
28
|
-
def apply(self, model: Model, time: float, rng: Generator)
|
|
32
|
+
def apply(self, model: Model, events: list[Event], time: float, rng: Generator):
|
|
29
33
|
individual = self.draw_individual(model, rng)
|
|
30
34
|
new_individual = model.birth_from(individual, self.child_state, time)
|
|
31
35
|
if CONTACTS_KEY not in model.context:
|
|
@@ -38,6 +42,8 @@ class BirthWithContactTracing(Event):
|
|
|
38
42
|
|
|
39
43
|
|
|
40
44
|
class SamplingWithContactTracing(Event):
|
|
45
|
+
type = EventType.SAMPLING_WITH_CT
|
|
46
|
+
|
|
41
47
|
def __init__(
|
|
42
48
|
self,
|
|
43
49
|
state: str,
|
|
@@ -49,7 +55,7 @@ class SamplingWithContactTracing(Event):
|
|
|
49
55
|
self.max_notified_contacts = max_notified_contacts
|
|
50
56
|
self.notification_probability = skyline_parameter(notification_probability)
|
|
51
57
|
|
|
52
|
-
def apply(self, model: Model, time: float, rng: Generator)
|
|
58
|
+
def apply(self, model: Model, events: list[Event], time: float, rng: Generator):
|
|
53
59
|
individual = self.draw_individual(model, rng)
|
|
54
60
|
model.sample(individual, time, True)
|
|
55
61
|
population = model.get_population()
|
|
@@ -60,8 +66,8 @@ class SamplingWithContactTracing(Event):
|
|
|
60
66
|
if contact in population:
|
|
61
67
|
state = model.get_state(contact)
|
|
62
68
|
p = self.notification_probability.get_value_at_time(time)
|
|
63
|
-
if not
|
|
64
|
-
model.migrate(contact,
|
|
69
|
+
if not is_CT_state(state) and rng.random() < p:
|
|
70
|
+
model.migrate(contact, get_CT_state(state), time)
|
|
65
71
|
|
|
66
72
|
def __repr__(self) -> str:
|
|
67
73
|
return f"SamplingWithContactTracing(state={self.state}, rate={self.rate}, max_notified_contacts={self.max_notified_contacts}, notification_probability={self.notification_probability})"
|
|
@@ -79,17 +85,24 @@ def get_contact_tracing_events(
|
|
|
79
85
|
sampling_rate_after_notification = skyline_parameter(
|
|
80
86
|
sampling_rate_after_notification
|
|
81
87
|
)
|
|
82
|
-
for event in events:
|
|
83
|
-
state, rate = event.state, event.rate
|
|
88
|
+
for event in [deepcopy(e) for e in events]:
|
|
84
89
|
if isinstance(event, Migration):
|
|
85
90
|
ct_events.append(event)
|
|
86
91
|
ct_events.append(
|
|
87
|
-
Migration(
|
|
92
|
+
Migration(
|
|
93
|
+
get_CT_state(event.state),
|
|
94
|
+
event.rate,
|
|
95
|
+
get_CT_state(event.target_state),
|
|
96
|
+
)
|
|
88
97
|
)
|
|
89
98
|
elif isinstance(event, Birth):
|
|
90
|
-
ct_events.append(BirthWithContactTracing(state, rate, event.child_state))
|
|
91
99
|
ct_events.append(
|
|
92
|
-
BirthWithContactTracing(
|
|
100
|
+
BirthWithContactTracing(event.state, event.rate, event.child_state)
|
|
101
|
+
)
|
|
102
|
+
ct_events.append(
|
|
103
|
+
BirthWithContactTracing(
|
|
104
|
+
get_CT_state(event.state), event.rate, event.child_state
|
|
105
|
+
)
|
|
93
106
|
)
|
|
94
107
|
elif isinstance(event, Sampling):
|
|
95
108
|
if not event.removal:
|
|
@@ -98,7 +111,10 @@ def get_contact_tracing_events(
|
|
|
98
111
|
)
|
|
99
112
|
ct_events.append(
|
|
100
113
|
SamplingWithContactTracing(
|
|
101
|
-
state,
|
|
114
|
+
event.state,
|
|
115
|
+
event.rate,
|
|
116
|
+
max_notified_contacts,
|
|
117
|
+
notification_probability,
|
|
102
118
|
)
|
|
103
119
|
)
|
|
104
120
|
elif isinstance(event, Death):
|
|
@@ -115,7 +131,7 @@ def get_contact_tracing_events(
|
|
|
115
131
|
):
|
|
116
132
|
ct_events.append(
|
|
117
133
|
SamplingWithContactTracing(
|
|
118
|
-
|
|
134
|
+
get_CT_state(state),
|
|
119
135
|
sampling_rate_after_notification,
|
|
120
136
|
max_notified_contacts,
|
|
121
137
|
notification_probability,
|
|
@@ -7,7 +7,8 @@ from phylogenie.skyline import (
|
|
|
7
7
|
skyline_matrix,
|
|
8
8
|
skyline_vector,
|
|
9
9
|
)
|
|
10
|
-
from phylogenie.treesimulator.
|
|
10
|
+
from phylogenie.treesimulator.events.base import Event, EventType
|
|
11
|
+
from phylogenie.treesimulator.model import Model
|
|
11
12
|
|
|
12
13
|
INFECTIOUS_STATE = "I"
|
|
13
14
|
EXPOSED_STATE = "E"
|
|
@@ -15,11 +16,13 @@ SUPERSPREADER_STATE = "S"
|
|
|
15
16
|
|
|
16
17
|
|
|
17
18
|
class Birth(Event):
|
|
19
|
+
type = EventType.BIRTH
|
|
20
|
+
|
|
18
21
|
def __init__(self, state: str, rate: SkylineParameterLike, child_state: str):
|
|
19
22
|
super().__init__(state, rate)
|
|
20
23
|
self.child_state = child_state
|
|
21
24
|
|
|
22
|
-
def apply(self, model: Model, time: float, rng: Generator)
|
|
25
|
+
def apply(self, model: Model, events: list[Event], time: float, rng: Generator):
|
|
23
26
|
individual = self.draw_individual(model, rng)
|
|
24
27
|
model.birth_from(individual, self.child_state, time)
|
|
25
28
|
|
|
@@ -28,7 +31,9 @@ class Birth(Event):
|
|
|
28
31
|
|
|
29
32
|
|
|
30
33
|
class Death(Event):
|
|
31
|
-
|
|
34
|
+
type = EventType.DEATH
|
|
35
|
+
|
|
36
|
+
def apply(self, model: Model, events: list[Event], time: float, rng: Generator):
|
|
32
37
|
individual = self.draw_individual(model, rng)
|
|
33
38
|
model.remove(individual, time)
|
|
34
39
|
|
|
@@ -37,11 +42,13 @@ class Death(Event):
|
|
|
37
42
|
|
|
38
43
|
|
|
39
44
|
class Migration(Event):
|
|
45
|
+
type = EventType.MIGRATION
|
|
46
|
+
|
|
40
47
|
def __init__(self, state: str, rate: SkylineParameterLike, target_state: str):
|
|
41
48
|
super().__init__(state, rate)
|
|
42
49
|
self.target_state = target_state
|
|
43
50
|
|
|
44
|
-
def apply(self, model: Model, time: float, rng: Generator)
|
|
51
|
+
def apply(self, model: Model, events: list[Event], time: float, rng: Generator):
|
|
45
52
|
individual = self.draw_individual(model, rng)
|
|
46
53
|
model.migrate(individual, self.target_state, time)
|
|
47
54
|
|
|
@@ -50,11 +57,13 @@ class Migration(Event):
|
|
|
50
57
|
|
|
51
58
|
|
|
52
59
|
class Sampling(Event):
|
|
60
|
+
type = EventType.SAMPLING
|
|
61
|
+
|
|
53
62
|
def __init__(self, state: str, rate: SkylineParameterLike, removal: bool):
|
|
54
63
|
super().__init__(state, rate)
|
|
55
64
|
self.removal = removal
|
|
56
65
|
|
|
57
|
-
def apply(self, model: Model, time: float, rng: Generator)
|
|
66
|
+
def apply(self, model: Model, events: list[Event], time: float, rng: Generator):
|
|
58
67
|
individual = self.draw_individual(model, rng)
|
|
59
68
|
model.sample(individual, time, self.removal)
|
|
60
69
|
|
|
@@ -1,27 +1,21 @@
|
|
|
1
1
|
import re
|
|
2
2
|
from copy import deepcopy
|
|
3
|
-
from
|
|
4
|
-
from typing import Type
|
|
3
|
+
from typing import Any
|
|
5
4
|
|
|
6
5
|
from numpy.random import Generator
|
|
7
6
|
|
|
8
7
|
from phylogenie.models import Distribution
|
|
9
8
|
from phylogenie.skyline import SkylineParameterLike
|
|
9
|
+
from phylogenie.treesimulator.events.base import Event, EventType
|
|
10
10
|
from phylogenie.treesimulator.events.contact_tracing import (
|
|
11
11
|
BirthWithContactTracing,
|
|
12
12
|
SamplingWithContactTracing,
|
|
13
13
|
)
|
|
14
|
-
from phylogenie.treesimulator.events.core import
|
|
15
|
-
Birth,
|
|
16
|
-
Death,
|
|
17
|
-
Event,
|
|
18
|
-
Migration,
|
|
19
|
-
Sampling,
|
|
20
|
-
)
|
|
14
|
+
from phylogenie.treesimulator.events.core import Birth, Death, Migration, Sampling
|
|
21
15
|
from phylogenie.treesimulator.model import Model
|
|
22
16
|
|
|
23
17
|
MUTATION_PREFIX = "MUT-"
|
|
24
|
-
|
|
18
|
+
NEXT_MUTATION_ID = "NEXT_MUTATION_ID"
|
|
25
19
|
|
|
26
20
|
|
|
27
21
|
def _get_mutation(state: str) -> str | None:
|
|
@@ -41,44 +35,42 @@ def get_mutation_id(node_name: str) -> int:
|
|
|
41
35
|
return 0
|
|
42
36
|
|
|
43
37
|
|
|
44
|
-
class TargetType(str, Enum):
|
|
45
|
-
BIRTH = "birth"
|
|
46
|
-
DEATH = "death"
|
|
47
|
-
MIGRATION = "migration"
|
|
48
|
-
SAMPLING = "sampling"
|
|
49
|
-
MUTATION = "mutation"
|
|
50
|
-
|
|
51
|
-
|
|
52
38
|
class Mutation(Event):
|
|
39
|
+
type = EventType.MUTATION
|
|
40
|
+
|
|
53
41
|
def __init__(
|
|
54
42
|
self,
|
|
55
43
|
state: str,
|
|
56
44
|
rate: SkylineParameterLike,
|
|
57
|
-
rate_scalers: dict[
|
|
45
|
+
rate_scalers: dict[EventType, Distribution],
|
|
58
46
|
):
|
|
59
47
|
super().__init__(state, rate)
|
|
60
48
|
self.rate_scalers = rate_scalers
|
|
61
49
|
|
|
62
|
-
def apply(
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
model.context
|
|
66
|
-
|
|
50
|
+
def apply(
|
|
51
|
+
self, model: Model, events: list[Event], time: float, rng: Generator
|
|
52
|
+
) -> dict[str, Any]:
|
|
53
|
+
if NEXT_MUTATION_ID not in model.context:
|
|
54
|
+
model.context[NEXT_MUTATION_ID] = 0
|
|
55
|
+
model.context[NEXT_MUTATION_ID] += 1
|
|
56
|
+
mutation_id = model.context[NEXT_MUTATION_ID]
|
|
67
57
|
|
|
68
58
|
individual = self.draw_individual(model, rng)
|
|
69
59
|
model.migrate(individual, _get_mutated_state(mutation_id, self.state), time)
|
|
70
60
|
|
|
71
|
-
rate_scalers = {
|
|
61
|
+
rate_scalers: dict[EventType, float] = {
|
|
72
62
|
target_type: getattr(rng, rate_scaler.type)(**rate_scaler.args)
|
|
73
63
|
for target_type, rate_scaler in self.rate_scalers.items()
|
|
74
64
|
}
|
|
75
65
|
|
|
66
|
+
metadata: dict[str, Any] = {}
|
|
76
67
|
for event in [
|
|
77
68
|
deepcopy(e)
|
|
78
|
-
for e in
|
|
69
|
+
for e in events
|
|
79
70
|
if _get_mutation(self.state) == _get_mutation(e.state)
|
|
80
71
|
]:
|
|
81
72
|
event.state = _get_mutated_state(mutation_id, event.state)
|
|
73
|
+
|
|
82
74
|
if isinstance(event, Birth | BirthWithContactTracing):
|
|
83
75
|
event.child_state = _get_mutated_state(mutation_id, event.child_state)
|
|
84
76
|
elif isinstance(event, Migration):
|
|
@@ -87,27 +79,20 @@ class Mutation(Event):
|
|
|
87
79
|
event, Mutation | Death | Sampling | SamplingWithContactTracing
|
|
88
80
|
):
|
|
89
81
|
raise ValueError(
|
|
90
|
-
f"Mutation not
|
|
82
|
+
f"Mutation not implemented for event of type {type(event)}."
|
|
91
83
|
)
|
|
92
84
|
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
)
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
model.add_event(event)
|
|
85
|
+
if event.type in rate_scalers:
|
|
86
|
+
event.rate *= rate_scalers[event.type]
|
|
87
|
+
metadata[f"{MUTATION_PREFIX}{mutation_id}.{event.type}.rate.value"] = (
|
|
88
|
+
event.rate.value[0]
|
|
89
|
+
if len(event.rate.value) == 1
|
|
90
|
+
else list(event.rate.value)
|
|
91
|
+
)
|
|
102
92
|
|
|
103
|
-
|
|
104
|
-
return f"Mutation(state={self.state}, rate={self.rate})"
|
|
93
|
+
events.append(event)
|
|
105
94
|
|
|
95
|
+
return metadata
|
|
106
96
|
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
TargetType.DEATH: (Death,),
|
|
110
|
-
TargetType.MIGRATION: (Migration,),
|
|
111
|
-
TargetType.SAMPLING: (Sampling, SamplingWithContactTracing),
|
|
112
|
-
TargetType.MUTATION: (Mutation,),
|
|
113
|
-
}
|
|
97
|
+
def __repr__(self) -> str:
|
|
98
|
+
return f"Mutation(state={self.state}, rate={self.rate}, rate_scalers={self.rate_scalers})"
|
|
@@ -2,7 +2,7 @@ from collections.abc import Iterable
|
|
|
2
2
|
from enum import Enum
|
|
3
3
|
|
|
4
4
|
from phylogenie.tree import Tree
|
|
5
|
-
from phylogenie.treesimulator.events import get_mutation_id
|
|
5
|
+
from phylogenie.treesimulator.events.mutations import get_mutation_id
|
|
6
6
|
from phylogenie.treesimulator.model import get_node_state
|
|
7
7
|
from phylogenie.utils import (
|
|
8
8
|
get_node_depth_levels,
|
|
@@ -1,16 +1,19 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import time
|
|
3
3
|
from collections.abc import Iterable, Sequence
|
|
4
|
+
from typing import Any
|
|
4
5
|
|
|
5
6
|
import joblib
|
|
6
7
|
import numpy as np
|
|
8
|
+
import pandas as pd
|
|
7
9
|
from numpy.random import default_rng
|
|
8
10
|
from tqdm import tqdm
|
|
9
11
|
|
|
10
12
|
from phylogenie.io import dump_newick
|
|
11
13
|
from phylogenie.tree import Tree
|
|
14
|
+
from phylogenie.treesimulator.events import Event
|
|
12
15
|
from phylogenie.treesimulator.features import Feature, set_features
|
|
13
|
-
from phylogenie.treesimulator.model import
|
|
16
|
+
from phylogenie.treesimulator.model import Model
|
|
14
17
|
|
|
15
18
|
|
|
16
19
|
def simulate_tree(
|
|
@@ -22,10 +25,7 @@ def simulate_tree(
|
|
|
22
25
|
sampling_probability_at_present: float = 0.0,
|
|
23
26
|
seed: int | None = None,
|
|
24
27
|
timeout: float = np.inf,
|
|
25
|
-
) -> Tree:
|
|
26
|
-
if max_time == np.inf and max_tips is None:
|
|
27
|
-
raise ValueError("Either max_time or max_tips must be specified.")
|
|
28
|
-
|
|
28
|
+
) -> tuple[Tree, dict[str, Any]]:
|
|
29
29
|
if max_time == np.inf and sampling_probability_at_present:
|
|
30
30
|
raise ValueError(
|
|
31
31
|
"sampling_probability_at_present cannot be set when max_time is infinite."
|
|
@@ -44,12 +44,16 @@ def simulate_tree(
|
|
|
44
44
|
rng = default_rng(seed)
|
|
45
45
|
start_clock = time.perf_counter()
|
|
46
46
|
while True:
|
|
47
|
-
model = Model(init_state
|
|
47
|
+
model = Model(init_state)
|
|
48
|
+
metadata: dict[str, Any] = {}
|
|
49
|
+
run_events = list(events)
|
|
48
50
|
current_time = 0.0
|
|
49
51
|
change_times = sorted(set(t for e in events for t in e.rate.change_times))
|
|
50
52
|
next_change_time = change_times.pop(0) if change_times else np.inf
|
|
53
|
+
|
|
51
54
|
if max_time == np.inf:
|
|
52
|
-
|
|
55
|
+
if max_tips is None:
|
|
56
|
+
raise ValueError("Either max_time or max_tips must be specified.")
|
|
53
57
|
target_n_tips = rng.integers(min_tips, max_tips + 1)
|
|
54
58
|
else:
|
|
55
59
|
target_n_tips = None
|
|
@@ -58,13 +62,12 @@ def simulate_tree(
|
|
|
58
62
|
if time.perf_counter() - start_clock > timeout:
|
|
59
63
|
raise TimeoutError("Simulation timed out.")
|
|
60
64
|
|
|
61
|
-
|
|
62
|
-
rates = [e.get_propensity(model, current_time) for e in events]
|
|
65
|
+
rates = [e.get_propensity(model, current_time) for e in run_events]
|
|
63
66
|
|
|
64
|
-
instantaneous_events = [e for e, r in zip(
|
|
67
|
+
instantaneous_events = [e for e, r in zip(run_events, rates) if r == np.inf]
|
|
65
68
|
if instantaneous_events:
|
|
66
69
|
event = instantaneous_events[rng.integers(len(instantaneous_events))]
|
|
67
|
-
event.apply(model, current_time, rng)
|
|
70
|
+
event.apply(model, run_events, current_time, rng)
|
|
68
71
|
continue
|
|
69
72
|
|
|
70
73
|
if (
|
|
@@ -87,7 +90,10 @@ def simulate_tree(
|
|
|
87
90
|
current_time += time_step
|
|
88
91
|
|
|
89
92
|
event_idx = np.searchsorted(np.cumsum(rates) / sum(rates), rng.random())
|
|
90
|
-
|
|
93
|
+
event = run_events[int(event_idx)]
|
|
94
|
+
event_metadata = event.apply(model, run_events, current_time, rng)
|
|
95
|
+
if event_metadata is not None:
|
|
96
|
+
metadata.update(event_metadata)
|
|
91
97
|
|
|
92
98
|
for individual in model.get_population():
|
|
93
99
|
if rng.random() < sampling_probability_at_present:
|
|
@@ -96,7 +102,7 @@ def simulate_tree(
|
|
|
96
102
|
if min_tips <= model.n_sampled and (
|
|
97
103
|
max_tips is None or model.n_sampled <= max_tips
|
|
98
104
|
):
|
|
99
|
-
return model.get_sampled_tree()
|
|
105
|
+
return (model.get_sampled_tree(), metadata)
|
|
100
106
|
|
|
101
107
|
|
|
102
108
|
def generate_trees(
|
|
@@ -112,11 +118,11 @@ def generate_trees(
|
|
|
112
118
|
seed: int | None = None,
|
|
113
119
|
n_jobs: int = -1,
|
|
114
120
|
timeout: float = np.inf,
|
|
115
|
-
) ->
|
|
116
|
-
def _simulate_tree(seed: int) -> Tree:
|
|
121
|
+
) -> pd.DataFrame:
|
|
122
|
+
def _simulate_tree(seed: int) -> tuple[Tree, dict[str, Any]]:
|
|
117
123
|
while True:
|
|
118
124
|
try:
|
|
119
|
-
tree = simulate_tree(
|
|
125
|
+
tree, metadata = simulate_tree(
|
|
120
126
|
events=events,
|
|
121
127
|
min_tips=min_tips,
|
|
122
128
|
max_tips=max_tips,
|
|
@@ -128,7 +134,7 @@ def generate_trees(
|
|
|
128
134
|
)
|
|
129
135
|
if node_features is not None:
|
|
130
136
|
set_features(tree, node_features)
|
|
131
|
-
return tree
|
|
137
|
+
return (tree, metadata)
|
|
132
138
|
except TimeoutError:
|
|
133
139
|
print("Simulation timed out, retrying with a different seed...")
|
|
134
140
|
seed += 1
|
|
@@ -142,7 +148,11 @@ def generate_trees(
|
|
|
142
148
|
joblib.delayed(_simulate_tree)(seed=int(rng.integers(2**32)))
|
|
143
149
|
for _ in range(n_trees)
|
|
144
150
|
)
|
|
145
|
-
|
|
151
|
+
|
|
152
|
+
df: list[dict[str, Any]] = []
|
|
153
|
+
for i, (tree, metadata) in tqdm(
|
|
146
154
|
enumerate(jobs), total=n_trees, desc=f"Generating trees in {output_dir}..."
|
|
147
155
|
):
|
|
156
|
+
df.append({"file_id": i} | metadata)
|
|
148
157
|
dump_newick(tree, os.path.join(output_dir, f"{i}.nwk"))
|
|
158
|
+
return pd.DataFrame(df)
|
|
@@ -1,13 +1,7 @@
|
|
|
1
|
-
from abc import ABC, abstractmethod
|
|
2
1
|
from collections import defaultdict
|
|
3
|
-
from collections.abc import Sequence
|
|
4
2
|
from dataclasses import dataclass
|
|
5
3
|
from typing import Any
|
|
6
4
|
|
|
7
|
-
import numpy as np
|
|
8
|
-
from numpy.random import Generator
|
|
9
|
-
|
|
10
|
-
from phylogenie.skyline import SkylineParameterLike, skyline_parameter
|
|
11
5
|
from phylogenie.tree import Tree
|
|
12
6
|
|
|
13
7
|
|
|
@@ -18,25 +12,6 @@ class Individual:
|
|
|
18
12
|
state: str
|
|
19
13
|
|
|
20
14
|
|
|
21
|
-
class Event(ABC):
|
|
22
|
-
def __init__(self, state: str, rate: SkylineParameterLike):
|
|
23
|
-
self.state = state
|
|
24
|
-
self.rate = skyline_parameter(rate)
|
|
25
|
-
|
|
26
|
-
def draw_individual(self, model: "Model", rng: Generator) -> int:
|
|
27
|
-
return rng.choice(model.get_population(self.state))
|
|
28
|
-
|
|
29
|
-
def get_propensity(self, model: "Model", time: float) -> float:
|
|
30
|
-
n_individuals = model.count_individuals(self.state)
|
|
31
|
-
rate = self.rate.get_value_at_time(time)
|
|
32
|
-
if rate == np.inf and not n_individuals:
|
|
33
|
-
return 0
|
|
34
|
-
return rate * n_individuals
|
|
35
|
-
|
|
36
|
-
@abstractmethod
|
|
37
|
-
def apply(self, model: "Model", time: float, rng: Generator) -> None: ...
|
|
38
|
-
|
|
39
|
-
|
|
40
15
|
def _get_node_name(node_id: int, state: str) -> str:
|
|
41
16
|
return f"{node_id}|{state}"
|
|
42
17
|
|
|
@@ -51,27 +26,19 @@ def get_node_state(node_name: str) -> str:
|
|
|
51
26
|
|
|
52
27
|
|
|
53
28
|
class Model:
|
|
54
|
-
def __init__(self, init_state: str
|
|
29
|
+
def __init__(self, init_state: str):
|
|
55
30
|
self._next_node_id = 0
|
|
56
31
|
self._next_individual_id = 0
|
|
57
32
|
self._population: dict[int, Individual] = {}
|
|
58
33
|
self._states: dict[str, set[int]] = defaultdict(set)
|
|
59
34
|
self._sampled: set[str] = set()
|
|
60
35
|
self._tree = self._get_new_individual(init_state).node
|
|
61
|
-
self._events = list(events)
|
|
62
36
|
self.context: dict[str, Any] = {}
|
|
63
37
|
|
|
64
38
|
@property
|
|
65
39
|
def n_sampled(self) -> int:
|
|
66
40
|
return len(self._sampled)
|
|
67
41
|
|
|
68
|
-
@property
|
|
69
|
-
def events(self) -> tuple[Event, ...]:
|
|
70
|
-
return tuple(self._events)
|
|
71
|
-
|
|
72
|
-
def add_event(self, event: Event) -> None:
|
|
73
|
-
self._events.append(event)
|
|
74
|
-
|
|
75
42
|
def _get_new_node(self, state: str) -> Tree:
|
|
76
43
|
self._next_node_id += 1
|
|
77
44
|
node = Tree(_get_node_name(self._next_node_id, state))
|
|
@@ -15,6 +15,6 @@ OneOrMany2DScalars = OneOrMany2D[Scalar]
|
|
|
15
15
|
Many2DScalars = Many2D[Scalar]
|
|
16
16
|
Many3DScalars = Many3D[Scalar]
|
|
17
17
|
|
|
18
|
-
Vector1D =
|
|
19
|
-
Vector2D =
|
|
20
|
-
Vector3D =
|
|
18
|
+
Vector1D = tuple[Scalar, ...]
|
|
19
|
+
Vector2D = tuple[Vector1D, ...]
|
|
20
|
+
Vector3D = tuple[Vector2D, ...]
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|