phylogenie 2.0.14__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- phylogenie/__init__.py +22 -7
- phylogenie/generators/__init__.py +0 -8
- phylogenie/generators/configs.py +11 -12
- phylogenie/generators/dataset.py +10 -7
- phylogenie/generators/factories.py +9 -0
- phylogenie/generators/trees.py +88 -61
- phylogenie/io.py +27 -12
- phylogenie/tree.py +38 -12
- phylogenie/treesimulator/__init__.py +18 -1
- phylogenie/treesimulator/events/__init__.py +39 -0
- phylogenie/treesimulator/events/contact_tracing.py +125 -0
- phylogenie/treesimulator/{events.py → events/core.py} +73 -125
- phylogenie/treesimulator/events/mutations.py +105 -0
- phylogenie/treesimulator/gillespie.py +77 -40
- phylogenie/treesimulator/model.py +57 -56
- phylogenie/typings.py +0 -1
- phylogenie/utils.py +17 -0
- {phylogenie-2.0.14.dist-info → phylogenie-2.1.1.dist-info}/METADATA +1 -2
- phylogenie-2.1.1.dist-info/RECORD +32 -0
- phylogenie-2.0.14.dist-info/RECORD +0 -28
- {phylogenie-2.0.14.dist-info → phylogenie-2.1.1.dist-info}/LICENSE.txt +0 -0
- {phylogenie-2.0.14.dist-info → phylogenie-2.1.1.dist-info}/WHEEL +0 -0
- {phylogenie-2.0.14.dist-info → phylogenie-2.1.1.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from phylogenie.treesimulator.events.contact_tracing import (
|
|
2
|
+
BirthWithContactTracing,
|
|
3
|
+
SamplingWithContactTracing,
|
|
4
|
+
get_contact_tracing_events,
|
|
5
|
+
)
|
|
6
|
+
from phylogenie.treesimulator.events.core import (
|
|
7
|
+
Birth,
|
|
8
|
+
Death,
|
|
9
|
+
Event,
|
|
10
|
+
Migration,
|
|
11
|
+
Sampling,
|
|
12
|
+
get_BD_events,
|
|
13
|
+
get_BDEI_events,
|
|
14
|
+
get_BDSS_events,
|
|
15
|
+
get_canonical_events,
|
|
16
|
+
get_epidemiological_events,
|
|
17
|
+
get_FBD_events,
|
|
18
|
+
)
|
|
19
|
+
from phylogenie.treesimulator.events.mutations import Mutation
|
|
20
|
+
from phylogenie.treesimulator.events.mutations import TargetType as MutationTargetType
|
|
21
|
+
|
|
22
|
+
__all__ = [
|
|
23
|
+
"Birth",
|
|
24
|
+
"BirthWithContactTracing",
|
|
25
|
+
"Death",
|
|
26
|
+
"Event",
|
|
27
|
+
"Migration",
|
|
28
|
+
"Mutation",
|
|
29
|
+
"Sampling",
|
|
30
|
+
"SamplingWithContactTracing",
|
|
31
|
+
"MutationTargetType",
|
|
32
|
+
"get_BD_events",
|
|
33
|
+
"get_BDEI_events",
|
|
34
|
+
"get_BDSS_events",
|
|
35
|
+
"get_canonical_events",
|
|
36
|
+
"get_contact_tracing_events",
|
|
37
|
+
"get_epidemiological_events",
|
|
38
|
+
"get_FBD_events",
|
|
39
|
+
]
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from collections.abc import Sequence
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
from numpy.random import Generator
|
|
6
|
+
|
|
7
|
+
from phylogenie.skyline import SkylineParameterLike, skyline_parameter
|
|
8
|
+
from phylogenie.treesimulator.events.core import Birth, Death, Migration, Sampling
|
|
9
|
+
from phylogenie.treesimulator.model import Event, Model
|
|
10
|
+
|
|
11
|
+
CT_POSTFIX = "-CT"
|
|
12
|
+
CONTACTS_KEY = "CONTACTS"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _get_CT_state(state: str) -> str:
    """Return the name of the contact-traced counterpart of ``state``.

    The counterpart is simply ``state`` followed by ``CT_POSTFIX``.
    """
    return "{}{}".format(state, CT_POSTFIX)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _is_CT_state(state: str) -> bool:
    """Tell whether ``state`` is a contact-traced state.

    A state counts as contact-traced when its name ends with ``CT_POSTFIX``.
    """
    has_ct_suffix = state.endswith(CT_POSTFIX)
    return has_ct_suffix
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class BirthWithContactTracing(Event):
    """Birth event that also records parent and child as mutual contacts.

    The contact lists live in ``model.context[CONTACTS_KEY]`` (created on
    first use) and are appended to in chronological order.
    """

    def __init__(self, state: str, rate: SkylineParameterLike, child_state: str):
        """Store the state given to the new individual next to the base fields."""
        super().__init__(state, rate)
        self.child_state = child_state

    def apply(self, model: Model, time: float, rng: Generator) -> None:
        """Draw a parent, give birth from it and register the contact pair."""
        parent = self.draw_individual(model, rng)
        child = model.birth_from(parent, self.child_state, time)

        context = model.context
        if CONTACTS_KEY not in context:
            context[CONTACTS_KEY] = defaultdict(list)
        contacts = context[CONTACTS_KEY]

        # Contacts are symmetric: each side remembers the other.
        contacts[parent].append(child)
        contacts[child].append(parent)

    def __repr__(self) -> str:
        return (
            f"BirthWithContactTracing(state={self.state}, rate={self.rate}, "
            f"child_state={self.child_state})"
        )
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class SamplingWithContactTracing(Event):
    """Sampling event that notifies the most recent contacts of the sampled individual.

    After sampling, up to ``max_notified_contacts`` of the individual's most
    recently recorded contacts are considered. Each one that is still in the
    population and not already in a contact-traced state is moved to the
    contact-traced counterpart of its state with probability
    ``notification_probability`` (evaluated at the event time).
    """

    def __init__(
        self,
        state: str,
        rate: SkylineParameterLike,
        max_notified_contacts: int,
        notification_probability: SkylineParameterLike,
    ):
        """Initialise the event.

        Args:
            state: State whose individuals can be sampled by this event.
            rate: Sampling rate.
            max_notified_contacts: Upper bound on how many contacts are
                notified per sampling; ``0`` disables notification.
            notification_probability: Probability that an eligible contact
                is actually notified.

        Raises:
            ValueError: If ``max_notified_contacts`` is negative.
        """
        if max_notified_contacts < 0:
            raise ValueError(
                f"max_notified_contacts must be non-negative, got {max_notified_contacts}."
            )
        super().__init__(state, rate)
        self.max_notified_contacts = max_notified_contacts
        self.notification_probability = skyline_parameter(notification_probability)

    def apply(self, model: Model, time: float, rng: Generator) -> None:
        """Sample one individual and notify its most recent contacts."""
        individual = self.draw_individual(model, rng)
        # Third argument is the removal flag (cf. Sampling.apply); the
        # get_contact_tracing_events builder rejects non-removal sampling.
        model.sample(individual, time, True)

        # `seq[-0:]` is the *whole* sequence, so a limit of zero must be
        # handled explicitly instead of being fed to the slice below.
        if self.max_notified_contacts <= 0 or CONTACTS_KEY not in model.context:
            return

        # `.get` avoids inserting an empty entry into the defaultdict for an
        # individual that never had any recorded contact.
        contacts = model.context[CONTACTS_KEY].get(individual, [])
        if not contacts:
            return

        population = model.get_population()
        # The probability depends only on the event time; evaluate it once.
        p = self.notification_probability.get_value_at_time(time)

        # Contacts are appended chronologically, so the tail holds the latest.
        for contact in contacts[-self.max_notified_contacts :]:
            if contact not in population:
                continue
            state = model.get_state(contact)
            # Short-circuit keeps the random draw limited to eligible contacts.
            if not _is_CT_state(state) and rng.random() < p:
                model.migrate(contact, _get_CT_state(state), time)

    def __repr__(self) -> str:
        return f"SamplingWithContactTracing(state={self.state}, rate={self.rate}, max_notified_contacts={self.max_notified_contacts}, notification_probability={self.notification_probability})"
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def get_contact_tracing_events(
    events: Sequence[Event],
    max_notified_contacts: int = 1,
    notification_probability: SkylineParameterLike = 1,
    sampling_rate_after_notification: SkylineParameterLike = np.inf,
    samplable_states_after_notification: list[str] | None = None,
) -> list[Event]:
    """Derive a contact-tracing event set from a plain event set.

    Each input event is translated as follows:

    * ``Migration``: kept, plus a copy acting between the contact-traced
      (CT) counterparts of its source and target states.
    * ``Birth``: replaced by ``BirthWithContactTracing`` for the original
      state and for its CT counterpart (same child state in both cases).
    * ``Sampling``: replaced by ``SamplingWithContactTracing``.
    * ``Death``: kept unchanged.

    Finally one ``SamplingWithContactTracing`` event with rate
    ``sampling_rate_after_notification`` is added for the CT counterpart of
    every state in ``samplable_states_after_notification``.

    Args:
        events: Events of the model without contact tracing.
        max_notified_contacts: Maximum number of contacts notified per sampling.
        notification_probability: Probability that a contact is notified.
        sampling_rate_after_notification: Sampling rate of notified individuals.
        samplable_states_after_notification: States whose CT counterpart can be
            sampled. Defaults to the distinct states of ``events``, in
            first-appearance order.

    Returns:
        The list of events of the contact-tracing model.

    Raises:
        ValueError: If a ``Sampling`` event does not have removal enabled.
        NotImplementedError: If an event is of an unsupported type.
    """
    notification_probability = skyline_parameter(notification_probability)
    sampling_rate_after_notification = skyline_parameter(
        sampling_rate_after_notification
    )

    ct_events: list[Event] = []
    for event in events:
        state, rate = event.state, event.rate
        if isinstance(event, Migration):
            ct_events.append(event)
            ct_events.append(
                Migration(_get_CT_state(state), rate, _get_CT_state(event.target_state))
            )
        elif isinstance(event, Birth):
            ct_events.append(BirthWithContactTracing(state, rate, event.child_state))
            ct_events.append(
                BirthWithContactTracing(_get_CT_state(state), rate, event.child_state)
            )
        elif isinstance(event, Sampling):
            if not event.removal:
                raise ValueError(
                    "Contact tracing requires removal to be set for all sampling events."
                )
            ct_events.append(
                SamplingWithContactTracing(
                    state, rate, max_notified_contacts, notification_probability
                )
            )
        elif isinstance(event, Death):
            ct_events.append(event)
        else:
            raise NotImplementedError(
                f"Unsupported event type {type(event)} for contact tracing."
            )

    if samplable_states_after_notification is None:
        # De-duplicate with dict.fromkeys rather than a set: set iteration
        # order of strings changes with hash randomisation, which would make
        # the order of the returned events differ from run to run.
        samplable_states = list(dict.fromkeys(e.state for e in events))
    else:
        samplable_states = samplable_states_after_notification

    for state in samplable_states:
        ct_events.append(
            SamplingWithContactTracing(
                _get_CT_state(state),
                sampling_rate_after_notification,
                max_notified_contacts,
                notification_probability,
            )
        )

    return ct_events
|
|
@@ -1,80 +1,73 @@
|
|
|
1
|
-
from
|
|
2
|
-
from collections.abc import Sequence
|
|
3
|
-
from dataclasses import dataclass
|
|
4
|
-
|
|
5
|
-
import numpy as np
|
|
1
|
+
from numpy.random import Generator
|
|
6
2
|
|
|
7
3
|
from phylogenie.skyline import (
|
|
8
4
|
SkylineMatrixCoercible,
|
|
9
|
-
SkylineParameter,
|
|
10
5
|
SkylineParameterLike,
|
|
11
6
|
SkylineVectorCoercible,
|
|
12
7
|
skyline_matrix,
|
|
13
|
-
skyline_parameter,
|
|
14
8
|
skyline_vector,
|
|
15
9
|
)
|
|
16
|
-
from phylogenie.treesimulator.model import
|
|
10
|
+
from phylogenie.treesimulator.model import Event, Model
|
|
17
11
|
|
|
18
12
|
INFECTIOUS_STATE = "I"
|
|
19
13
|
EXPOSED_STATE = "E"
|
|
20
14
|
SUPERSPREADER_STATE = "S"
|
|
21
15
|
|
|
22
16
|
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
17
|
+
class Birth(Event):
|
|
18
|
+
def __init__(self, state: str, rate: SkylineParameterLike, child_state: str):
|
|
19
|
+
super().__init__(state, rate)
|
|
20
|
+
self.child_state = child_state
|
|
27
21
|
|
|
28
|
-
def
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
if rate == np.inf and not n_individuals:
|
|
32
|
-
return 0
|
|
33
|
-
return rate * n_individuals
|
|
22
|
+
def apply(self, model: Model, time: float, rng: Generator) -> None:
|
|
23
|
+
individual = self.draw_individual(model, rng)
|
|
24
|
+
model.birth_from(individual, self.child_state, time)
|
|
34
25
|
|
|
35
|
-
|
|
36
|
-
|
|
26
|
+
def __repr__(self) -> str:
|
|
27
|
+
return f"Birth(state={self.state}, rate={self.rate}, child_state={self.child_state})"
|
|
37
28
|
|
|
38
29
|
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
30
|
+
class Death(Event):
|
|
31
|
+
def apply(self, model: Model, time: float, rng: Generator) -> None:
|
|
32
|
+
individual = self.draw_individual(model, rng)
|
|
33
|
+
model.remove(individual, time)
|
|
42
34
|
|
|
43
|
-
def
|
|
44
|
-
|
|
45
|
-
model.birth_from(individual, self.child_state, time)
|
|
35
|
+
def __repr__(self) -> str:
|
|
36
|
+
return f"Death(state={self.state}, rate={self.rate})"
|
|
46
37
|
|
|
47
38
|
|
|
48
|
-
class
|
|
49
|
-
def
|
|
50
|
-
|
|
51
|
-
|
|
39
|
+
class Migration(Event):
|
|
40
|
+
def __init__(self, state: str, rate: SkylineParameterLike, target_state: str):
|
|
41
|
+
super().__init__(state, rate)
|
|
42
|
+
self.target_state = target_state
|
|
52
43
|
|
|
44
|
+
def apply(self, model: Model, time: float, rng: Generator) -> None:
|
|
45
|
+
individual = self.draw_individual(model, rng)
|
|
46
|
+
model.migrate(individual, self.target_state, time)
|
|
53
47
|
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
target_state: str
|
|
48
|
+
def __repr__(self) -> str:
|
|
49
|
+
return f"Migration(state={self.state}, rate={self.rate}, target_state={self.target_state})"
|
|
57
50
|
|
|
58
|
-
def apply(self, model: Model, time: float) -> None:
|
|
59
|
-
individual = model.get_random_individual(self.state)
|
|
60
|
-
model.migrate(individual, self.target_state, time)
|
|
61
51
|
|
|
52
|
+
class Sampling(Event):
|
|
53
|
+
def __init__(self, state: str, rate: SkylineParameterLike, removal: bool):
|
|
54
|
+
super().__init__(state, rate)
|
|
55
|
+
self.removal = removal
|
|
62
56
|
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
57
|
+
def apply(self, model: Model, time: float, rng: Generator) -> None:
|
|
58
|
+
individual = self.draw_individual(model, rng)
|
|
59
|
+
model.sample(individual, time, self.removal)
|
|
66
60
|
|
|
67
|
-
def
|
|
68
|
-
|
|
69
|
-
model.sample(individual, time, self.removal_probability.get_value_at_time(time))
|
|
61
|
+
def __repr__(self) -> str:
|
|
62
|
+
return f"Sampling(state={self.state}, rate={self.rate}, removal={self.removal})"
|
|
70
63
|
|
|
71
64
|
|
|
72
65
|
def get_canonical_events(
|
|
73
|
-
states:
|
|
66
|
+
states: list[str],
|
|
74
67
|
sampling_rates: SkylineVectorCoercible,
|
|
68
|
+
remove_after_sampling: bool,
|
|
75
69
|
birth_rates: SkylineVectorCoercible = 0,
|
|
76
70
|
death_rates: SkylineVectorCoercible = 0,
|
|
77
|
-
removal_probabilities: SkylineVectorCoercible = 0,
|
|
78
71
|
migration_rates: SkylineMatrixCoercible | None = None,
|
|
79
72
|
birth_rates_among_states: SkylineMatrixCoercible | None = None,
|
|
80
73
|
) -> list[Event]:
|
|
@@ -83,102 +76,94 @@ def get_canonical_events(
|
|
|
83
76
|
birth_rates = skyline_vector(birth_rates, N)
|
|
84
77
|
death_rates = skyline_vector(death_rates, N)
|
|
85
78
|
sampling_rates = skyline_vector(sampling_rates, N)
|
|
86
|
-
removal_probabilities = skyline_vector(removal_probabilities, N)
|
|
87
79
|
|
|
88
80
|
events: list[Event] = []
|
|
89
81
|
for i, state in enumerate(states):
|
|
90
|
-
events.append(
|
|
91
|
-
events.append(
|
|
92
|
-
events.append(
|
|
82
|
+
events.append(Birth(state, birth_rates[i], state))
|
|
83
|
+
events.append(Death(state, death_rates[i]))
|
|
84
|
+
events.append(Sampling(state, sampling_rates[i], remove_after_sampling))
|
|
93
85
|
|
|
94
86
|
if migration_rates is not None:
|
|
95
87
|
migration_rates = skyline_matrix(migration_rates, N, N - 1)
|
|
96
88
|
for i, state in enumerate(states):
|
|
97
89
|
for j, other_state in enumerate([s for s in states if s != state]):
|
|
98
|
-
events.append(
|
|
90
|
+
events.append(Migration(state, migration_rates[i, j], other_state))
|
|
99
91
|
|
|
100
92
|
if birth_rates_among_states is not None:
|
|
101
93
|
birth_rates_among_states = skyline_matrix(birth_rates_among_states, N, N - 1)
|
|
102
94
|
for i, state in enumerate(states):
|
|
103
95
|
for j, other_state in enumerate([s for s in states if s != state]):
|
|
104
|
-
events.append(
|
|
105
|
-
BirthEvent(birth_rates_among_states[i, j], state, other_state)
|
|
106
|
-
)
|
|
96
|
+
events.append(Birth(state, birth_rates_among_states[i, j], other_state))
|
|
107
97
|
|
|
108
98
|
return [event for event in events if event.rate]
|
|
109
99
|
|
|
110
100
|
|
|
111
|
-
def
|
|
112
|
-
states:
|
|
101
|
+
def get_FBD_events(
|
|
102
|
+
states: list[str],
|
|
113
103
|
sampling_proportions: SkylineVectorCoercible = 1,
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
removal_probabilities: SkylineVectorCoercible = 1,
|
|
104
|
+
diversification: SkylineVectorCoercible = 0,
|
|
105
|
+
turnover: SkylineVectorCoercible = 0,
|
|
117
106
|
migration_rates: SkylineMatrixCoercible | None = None,
|
|
118
|
-
|
|
107
|
+
diversification_between_states: SkylineMatrixCoercible | None = None,
|
|
119
108
|
) -> list[Event]:
|
|
120
109
|
N = len(states)
|
|
121
110
|
|
|
122
|
-
|
|
123
|
-
|
|
111
|
+
diversification = skyline_vector(diversification, N)
|
|
112
|
+
turnover = skyline_vector(turnover, N)
|
|
124
113
|
sampling_proportions = skyline_vector(sampling_proportions, N)
|
|
125
|
-
removal_probabilities = skyline_vector(removal_probabilities, N)
|
|
126
114
|
|
|
127
|
-
birth_rates =
|
|
128
|
-
|
|
129
|
-
|
|
115
|
+
birth_rates = diversification / (1 - turnover)
|
|
116
|
+
death_rates = turnover * birth_rates
|
|
117
|
+
sampling_rates = sampling_proportions * death_rates
|
|
130
118
|
birth_rates_among_states = (
|
|
131
|
-
(
|
|
132
|
-
|
|
133
|
-
* become_uninfectious_rates
|
|
134
|
-
)
|
|
135
|
-
if reproduction_numbers_among_states is not None
|
|
119
|
+
(skyline_matrix(diversification_between_states, N, N - 1) + death_rates)
|
|
120
|
+
if diversification_between_states is not None
|
|
136
121
|
else None
|
|
137
122
|
)
|
|
138
123
|
|
|
139
124
|
return get_canonical_events(
|
|
140
125
|
states=states,
|
|
126
|
+
sampling_rates=sampling_rates,
|
|
127
|
+
remove_after_sampling=False,
|
|
141
128
|
birth_rates=birth_rates,
|
|
142
129
|
death_rates=death_rates,
|
|
143
|
-
sampling_rates=sampling_rates,
|
|
144
|
-
removal_probabilities=removal_probabilities,
|
|
145
130
|
migration_rates=migration_rates,
|
|
146
131
|
birth_rates_among_states=birth_rates_among_states,
|
|
147
132
|
)
|
|
148
133
|
|
|
149
134
|
|
|
150
|
-
def
|
|
151
|
-
states:
|
|
152
|
-
diversification: SkylineVectorCoercible = 0,
|
|
153
|
-
turnover: SkylineVectorCoercible = 0,
|
|
135
|
+
def get_epidemiological_events(
|
|
136
|
+
states: list[str],
|
|
154
137
|
sampling_proportions: SkylineVectorCoercible = 1,
|
|
155
|
-
|
|
138
|
+
reproduction_numbers: SkylineVectorCoercible = 0,
|
|
139
|
+
become_uninfectious_rates: SkylineVectorCoercible = 0,
|
|
156
140
|
migration_rates: SkylineMatrixCoercible | None = None,
|
|
157
|
-
|
|
141
|
+
reproduction_numbers_among_states: SkylineMatrixCoercible | None = None,
|
|
158
142
|
) -> list[Event]:
|
|
159
143
|
N = len(states)
|
|
160
144
|
|
|
161
|
-
|
|
162
|
-
|
|
145
|
+
reproduction_numbers = skyline_vector(reproduction_numbers, N)
|
|
146
|
+
become_uninfectious_rates = skyline_vector(become_uninfectious_rates, N)
|
|
163
147
|
sampling_proportions = skyline_vector(sampling_proportions, N)
|
|
164
|
-
removal_probabilities = skyline_vector(removal_probabilities, N)
|
|
165
148
|
|
|
166
|
-
birth_rates =
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
sampling_rates = sampling_proportions * death_rates / sampling_rates_dividend
|
|
149
|
+
birth_rates = reproduction_numbers * become_uninfectious_rates
|
|
150
|
+
sampling_rates = become_uninfectious_rates * sampling_proportions
|
|
151
|
+
death_rates = become_uninfectious_rates - sampling_rates
|
|
170
152
|
birth_rates_among_states = (
|
|
171
|
-
(
|
|
172
|
-
|
|
153
|
+
(
|
|
154
|
+
skyline_matrix(reproduction_numbers_among_states, N, N - 1)
|
|
155
|
+
* become_uninfectious_rates
|
|
156
|
+
)
|
|
157
|
+
if reproduction_numbers_among_states is not None
|
|
173
158
|
else None
|
|
174
159
|
)
|
|
175
160
|
|
|
176
161
|
return get_canonical_events(
|
|
177
162
|
states=states,
|
|
163
|
+
sampling_rates=sampling_rates,
|
|
164
|
+
remove_after_sampling=True,
|
|
178
165
|
birth_rates=birth_rates,
|
|
179
166
|
death_rates=death_rates,
|
|
180
|
-
sampling_rates=sampling_rates,
|
|
181
|
-
removal_probabilities=removal_probabilities,
|
|
182
167
|
migration_rates=migration_rates,
|
|
183
168
|
birth_rates_among_states=birth_rates_among_states,
|
|
184
169
|
)
|
|
@@ -232,40 +217,3 @@ def get_BDSS_events(
|
|
|
232
217
|
become_uninfectious_rates=1 / infectious_period,
|
|
233
218
|
sampling_proportions=sampling_proportion,
|
|
234
219
|
)
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
def get_contact_tracing_events(
|
|
238
|
-
events: Sequence[Event],
|
|
239
|
-
samplable_states_after_notification: Sequence[str] | None = None,
|
|
240
|
-
sampling_rate_after_notification: SkylineParameterLike = np.inf,
|
|
241
|
-
contacts_removal_probability: SkylineParameterLike = 1,
|
|
242
|
-
) -> list[Event]:
|
|
243
|
-
ct_events = list(events)
|
|
244
|
-
for event in events:
|
|
245
|
-
if isinstance(event, MigrationEvent):
|
|
246
|
-
ct_events.append(
|
|
247
|
-
MigrationEvent(
|
|
248
|
-
event.rate,
|
|
249
|
-
get_CT_state(event.state),
|
|
250
|
-
get_CT_state(event.target_state),
|
|
251
|
-
)
|
|
252
|
-
)
|
|
253
|
-
elif isinstance(event, BirthEvent):
|
|
254
|
-
ct_events.append(
|
|
255
|
-
BirthEvent(event.rate, get_CT_state(event.state), event.child_state)
|
|
256
|
-
)
|
|
257
|
-
|
|
258
|
-
for state in (
|
|
259
|
-
samplable_states_after_notification
|
|
260
|
-
if samplable_states_after_notification is not None
|
|
261
|
-
else [e.state for e in events]
|
|
262
|
-
):
|
|
263
|
-
ct_events.append(
|
|
264
|
-
SamplingEvent(
|
|
265
|
-
skyline_parameter(sampling_rate_after_notification),
|
|
266
|
-
get_CT_state(state),
|
|
267
|
-
skyline_parameter(contacts_removal_probability),
|
|
268
|
-
)
|
|
269
|
-
)
|
|
270
|
-
|
|
271
|
-
return ct_events
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
from copy import deepcopy
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from typing import Type
|
|
4
|
+
|
|
5
|
+
from numpy.random import Generator
|
|
6
|
+
|
|
7
|
+
from phylogenie.skyline import SkylineParameterLike
|
|
8
|
+
from phylogenie.treesimulator.events.contact_tracing import (
|
|
9
|
+
BirthWithContactTracing,
|
|
10
|
+
SamplingWithContactTracing,
|
|
11
|
+
)
|
|
12
|
+
from phylogenie.treesimulator.events.core import (
|
|
13
|
+
Birth,
|
|
14
|
+
Death,
|
|
15
|
+
Event,
|
|
16
|
+
Migration,
|
|
17
|
+
Sampling,
|
|
18
|
+
)
|
|
19
|
+
from phylogenie.treesimulator.model import Model
|
|
20
|
+
from phylogenie.utils import Distribution
|
|
21
|
+
|
|
22
|
+
MUTATION_PREFIX = "MUT-"
|
|
23
|
+
MUTATIONS_KEY = "MUTATIONS"
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _get_mutation(state: str) -> str | None:
|
|
27
|
+
return state.split(".")[0] if state.startswith(MUTATION_PREFIX) else None
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _get_mutated_state(mutation_id: int, state: str) -> str:
|
|
31
|
+
if state.startswith(MUTATION_PREFIX):
|
|
32
|
+
state = state.split(".")[1]
|
|
33
|
+
return f"{MUTATION_PREFIX}{mutation_id}.{state}"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class TargetType(str, Enum):
    """Kinds of event whose rate a mutation can rescale.

    Members are used as the keys of ``Mutation.rate_scalers`` and of the
    module-level ``TARGETS`` table. Because the enum also derives from
    ``str``, each member compares equal to its lowercase string value.
    """

    BIRTH = "birth"
    DEATH = "death"
    MIGRATION = "migration"
    SAMPLING = "sampling"
    MUTATION = "mutation"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class Mutation(Event):
    """Event that moves one individual into a freshly created mutant lineage.

    Applying the event assigns a new mutation id, migrates the drawn
    individual to the mutated version of ``self.state`` and registers, for the
    new lineage, a copy of every event of the lineage ``self.state`` belongs
    to. The copied events' rates are multiplied by factors drawn from
    ``rate_scalers``.
    """

    def __init__(
        self,
        state: str,
        rate: SkylineParameterLike,
        rate_scalers: dict[TargetType, Distribution],
    ):
        """Initialise the event.

        Args:
            state: State whose individuals can mutate.
            rate: Mutation rate.
            rate_scalers: For each targeted event kind, the distribution the
                multiplicative rate factor is drawn from. ``type`` names a
                method of the random generator and ``args`` its keyword
                arguments.
        """
        super().__init__(state, rate)
        self.rate_scalers = rate_scalers

    def apply(self, model: Model, time: float, rng: Generator) -> None:
        """Create a new mutant lineage from one individual in ``self.state``.

        Raises:
            ValueError: If ``rate_scalers`` contains an unsupported target
                type, or if an event of the lineage is of a type mutations
                are not defined for.
        """
        # Validate the configuration before touching the model, so that a bad
        # target type cannot leave a half-applied mutation behind.
        for target_type in self.rate_scalers:
            if target_type not in TARGETS:
                raise ValueError(
                    f"Unsupported target type {target_type} for mutation."
                )

        if MUTATIONS_KEY not in model.context:
            model.context[MUTATIONS_KEY] = 0
        model.context[MUTATIONS_KEY] += 1
        mutation_id = model.context[MUTATIONS_KEY]

        individual = self.draw_individual(model, rng)
        model.migrate(individual, _get_mutated_state(mutation_id, self.state), time)

        # One factor per target type, shared by all events of the new lineage.
        rate_scalers = {
            target_type: getattr(rng, rate_scaler.type)(**rate_scaler.args)
            for target_type, rate_scaler in self.rate_scalers.items()
        }

        # Copy the events up front: model.add_event is called while looping,
        # and the originals must keep their own state and rate.
        source_mutation = _get_mutation(self.state)
        lineage_events = [
            deepcopy(e)
            for e in model.events
            if _get_mutation(e.state) == source_mutation
        ]

        for event in lineage_events:
            event.state = _get_mutated_state(mutation_id, event.state)
            if isinstance(event, Birth | BirthWithContactTracing):
                event.child_state = _get_mutated_state(mutation_id, event.child_state)
            elif isinstance(event, Migration):
                event.target_state = _get_mutated_state(mutation_id, event.target_state)
            elif not isinstance(
                event, Mutation | Death | Sampling | SamplingWithContactTracing
            ):
                raise ValueError(
                    f"Mutation not defined for event of type {type(event)}."
                )

            for target_type, rate_scaler in rate_scalers.items():
                if isinstance(event, TARGETS[target_type]):
                    event.rate *= rate_scaler

            model.add_event(event)

    def __repr__(self) -> str:
        return f"Mutation(state={self.state}, rate={self.rate}, rate_scalers={self.rate_scalers})"
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
# Event classes affected by each target type: Mutation.apply multiplies an
# event's rate by the scaler of every target type whose tuple matches the
# event via isinstance. Defined after the Mutation class because the
# MUTATION entry refers to that class.
TARGETS: dict[TargetType, tuple[Type[Event], ...]] = {
    TargetType.BIRTH: (Birth, BirthWithContactTracing),
    TargetType.DEATH: (Death,),
    TargetType.MIGRATION: (Migration,),
    TargetType.SAMPLING: (Sampling, SamplingWithContactTracing),
    TargetType.MUTATION: (Mutation,),
}
|