phylogenie 2.0.14__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,39 @@
1
+ from phylogenie.treesimulator.events.contact_tracing import (
2
+ BirthWithContactTracing,
3
+ SamplingWithContactTracing,
4
+ get_contact_tracing_events,
5
+ )
6
+ from phylogenie.treesimulator.events.core import (
7
+ Birth,
8
+ Death,
9
+ Event,
10
+ Migration,
11
+ Sampling,
12
+ get_BD_events,
13
+ get_BDEI_events,
14
+ get_BDSS_events,
15
+ get_canonical_events,
16
+ get_epidemiological_events,
17
+ get_FBD_events,
18
+ )
19
+ from phylogenie.treesimulator.events.mutations import Mutation
20
+ from phylogenie.treesimulator.events.mutations import TargetType as MutationTargetType
21
+
22
# Public API of the events package, re-exported from the ``core``,
# ``contact_tracing`` and ``mutations`` submodules.
__all__ = [
    # Event classes.
    "Birth",
    "BirthWithContactTracing",
    "Death",
    "Event",
    "Migration",
    "Mutation",
    "Sampling",
    "SamplingWithContactTracing",
    # Enum of rate targets a mutation can rescale (re-exported alias).
    "MutationTargetType",
    # Event-list builders.
    "get_BD_events",
    "get_BDEI_events",
    "get_BDSS_events",
    "get_canonical_events",
    "get_contact_tracing_events",
    "get_epidemiological_events",
    "get_FBD_events",
]
@@ -0,0 +1,125 @@
1
+ from collections import defaultdict
2
+ from collections.abc import Sequence
3
+
4
+ import numpy as np
5
+ from numpy.random import Generator
6
+
7
+ from phylogenie.skyline import SkylineParameterLike, skyline_parameter
8
+ from phylogenie.treesimulator.events.core import Birth, Death, Migration, Sampling
9
+ from phylogenie.treesimulator.model import Event, Model
10
+
11
+ CT_POSTFIX = "-CT"
12
+ CONTACTS_KEY = "CONTACTS"
13
+
14
+
15
+ def _get_CT_state(state: str) -> str:
16
+ return f"{state}{CT_POSTFIX}"
17
+
18
+
19
+ def _is_CT_state(state: str) -> bool:
20
+ return state.endswith(CT_POSTFIX)
21
+
22
+
23
class BirthWithContactTracing(Event):
    """Birth event that also records parent and child as mutual contacts.

    Contacts are kept in ``model.context[CONTACTS_KEY]``, a ``defaultdict(list)``
    (created lazily on first use) mapping each individual to the individuals
    it has been in contact with, in the order the contacts were recorded.

    Args:
        state: State of the individuals that give birth.
        rate: Rate of the event.
        child_state: State assigned to the newborn individual.
    """

    def __init__(self, state: str, rate: SkylineParameterLike, child_state: str):
        super().__init__(state, rate)
        self.child_state = child_state

    def apply(self, model: Model, time: float, rng: Generator) -> None:
        parent = self.draw_individual(model, rng)
        child = model.birth_from(parent, self.child_state, time)

        context = model.context
        if CONTACTS_KEY not in context:
            context[CONTACTS_KEY] = defaultdict(list)
        contacts = context[CONTACTS_KEY]
        # Contacts are symmetric: each side remembers the other.
        contacts[parent].append(child)
        contacts[child].append(parent)

    def __repr__(self) -> str:
        return (
            f"BirthWithContactTracing(state={self.state}, rate={self.rate}, "
            f"child_state={self.child_state})"
        )
38
+
39
+
40
class SamplingWithContactTracing(Event):
    """Sampling event (always with removal) that may notify recent contacts.

    When an individual is sampled, up to ``max_notified_contacts`` of its most
    recently recorded contacts (stored in ``model.context[CONTACTS_KEY]``) that
    are still in the population are each notified with probability
    ``notification_probability`` evaluated at the event time. A notified
    contact is migrated to its contact-traced state (``<state>-CT``). Contacts
    that are already in a contact-traced state are left untouched.

    Args:
        state: State of the individuals this event applies to.
        rate: Rate of the event.
        max_notified_contacts: Maximum number of most recent contacts that can
            be notified per sampling; ``0`` disables notification.
        notification_probability: Probability that each eligible contact is
            notified.

    Raises:
        ValueError: If ``max_notified_contacts`` is negative.
    """

    def __init__(
        self,
        state: str,
        rate: SkylineParameterLike,
        max_notified_contacts: int,
        notification_probability: SkylineParameterLike,
    ):
        if max_notified_contacts < 0:
            raise ValueError(
                "max_notified_contacts must be non-negative, "
                f"got {max_notified_contacts}."
            )
        super().__init__(state, rate)
        self.max_notified_contacts = max_notified_contacts
        self.notification_probability = skyline_parameter(notification_probability)

    def apply(self, model: Model, time: float, rng: Generator) -> None:
        individual = self.draw_individual(model, rng)
        model.sample(individual, time, True)

        # Guard the zero case explicitly: contacts[-0:] would be the WHOLE list,
        # i.e. every contact would be notified instead of none.
        if CONTACTS_KEY not in model.context or self.max_notified_contacts == 0:
            return

        # .get avoids inserting an empty entry into the defaultdict for
        # individuals that never had a contact.
        contacts = model.context[CONTACTS_KEY].get(individual, [])
        recent_contacts = contacts[-self.max_notified_contacts :]
        if not recent_contacts:
            return

        population = model.get_population()
        p = self.notification_probability.get_value_at_time(time)
        for contact in recent_contacts:
            if contact not in population:
                continue
            state = model.get_state(contact)
            # Short-circuit: the RNG is only consumed for not-yet-traced contacts.
            if not _is_CT_state(state) and rng.random() < p:
                model.migrate(contact, _get_CT_state(state), time)

    def __repr__(self) -> str:
        return f"SamplingWithContactTracing(state={self.state}, rate={self.rate}, max_notified_contacts={self.max_notified_contacts}, notification_probability={self.notification_probability})"
68
+
69
+
70
def get_contact_tracing_events(
    events: Sequence[Event],
    max_notified_contacts: int = 1,
    notification_probability: SkylineParameterLike = 1,
    sampling_rate_after_notification: SkylineParameterLike = np.inf,
    samplable_states_after_notification: list[str] | None = None,
) -> list[Event]:
    """Rewrite a list of events so that the model performs contact tracing.

    Each input event is mapped as follows:

    * ``Migration``: kept, plus an equivalent migration between the
      corresponding contact-traced (CT) states.
    * ``Birth``: replaced by ``BirthWithContactTracing`` for the original state
      and for its CT state (the child keeps the original ``child_state``).
    * ``Sampling``: must remove the sampled individual; replaced by
      ``SamplingWithContactTracing`` for the original state.
    * ``Death``: kept unchanged.

    In addition, for the CT state of every state in
    ``samplable_states_after_notification`` a ``SamplingWithContactTracing``
    event with rate ``sampling_rate_after_notification`` is appended. By
    default these are all states appearing in ``events``, deduplicated in
    first-seen order so that the returned list is identical across runs.

    NOTE(review): ``Death`` events get no CT counterpart, so notified
    individuals can only leave through post-notification sampling. Confirm
    this is intended.

    Raises:
        ValueError: If a ``Sampling`` event does not remove the sampled individual.
        NotImplementedError: For any other event type.
    """
    ct_events: list[Event] = []
    notification_probability = skyline_parameter(notification_probability)
    sampling_rate_after_notification = skyline_parameter(
        sampling_rate_after_notification
    )
    for event in events:
        state, rate = event.state, event.rate
        if isinstance(event, Migration):
            ct_events.append(event)
            ct_events.append(
                Migration(_get_CT_state(state), rate, _get_CT_state(event.target_state))
            )
        elif isinstance(event, Birth):
            ct_events.append(BirthWithContactTracing(state, rate, event.child_state))
            ct_events.append(
                BirthWithContactTracing(_get_CT_state(state), rate, event.child_state)
            )
        elif isinstance(event, Sampling):
            if not event.removal:
                raise ValueError(
                    "Contact tracing requires removal to be set for all sampling events."
                )
            ct_events.append(
                SamplingWithContactTracing(
                    state, rate, max_notified_contacts, notification_probability
                )
            )
        elif isinstance(event, Death):
            ct_events.append(event)
        else:
            raise NotImplementedError(
                f"Unsupported event type {type(event)} for contact tracing."
            )

    if samplable_states_after_notification is None:
        # dict.fromkeys dedupes while keeping first-seen order. A set of str
        # iterates in a hash-randomized order (PYTHONHASHSEED), which made the
        # event list differ from one interpreter run to the next.
        notified_states = list(dict.fromkeys(e.state for e in events))
    else:
        notified_states = samplable_states_after_notification

    for state in notified_states:
        ct_events.append(
            SamplingWithContactTracing(
                _get_CT_state(state),
                sampling_rate_after_notification,
                max_notified_contacts,
                notification_probability,
            )
        )

    return ct_events
@@ -1,80 +1,73 @@
1
- from abc import ABC, abstractmethod
2
- from collections.abc import Sequence
3
- from dataclasses import dataclass
4
-
5
- import numpy as np
1
+ from numpy.random import Generator
6
2
 
7
3
  from phylogenie.skyline import (
8
4
  SkylineMatrixCoercible,
9
- SkylineParameter,
10
5
  SkylineParameterLike,
11
6
  SkylineVectorCoercible,
12
7
  skyline_matrix,
13
- skyline_parameter,
14
8
  skyline_vector,
15
9
  )
16
- from phylogenie.treesimulator.model import Model, get_CT_state
10
+ from phylogenie.treesimulator.model import Event, Model
17
11
 
18
12
# Default state labels (presumably used by the BD/BDEI/BDSS builders further
# down this module -- not visible here).
INFECTIOUS_STATE, EXPOSED_STATE, SUPERSPREADER_STATE = "I", "E", "S"
21
15
 
22
16
 
23
class Birth(Event):
    """Event in which a drawn individual gives birth to a child.

    Args:
        state: State of the individuals this event applies to (the individual
            is drawn via ``Event.draw_individual``; presumably among those in
            ``state``).
        rate: Rate of the event.
        child_state: State assigned to the newborn individual.
    """

    def __init__(self, state: str, rate: SkylineParameterLike, child_state: str):
        super().__init__(state, rate)
        self.child_state = child_state

    def apply(self, model: Model, time: float, rng: Generator) -> None:
        parent = self.draw_individual(model, rng)
        model.birth_from(parent, self.child_state, time)

    def __repr__(self) -> str:
        return (
            f"Birth(state={self.state}, rate={self.rate}, "
            f"child_state={self.child_state})"
        )
37
28
 
38
29
 
39
class Death(Event):
    """Event that removes a drawn individual from the model (no sampling)."""

    def apply(self, model: Model, time: float, rng: Generator) -> None:
        victim = self.draw_individual(model, rng)
        model.remove(victim, time)

    def __repr__(self) -> str:
        return f"Death(state={self.state}, rate={self.rate})"
46
37
 
47
38
 
48
class Migration(Event):
    """Event that moves a drawn individual into ``target_state``.

    Args:
        state: State of the individuals this event applies to.
        rate: Rate of the event.
        target_state: State the drawn individual is migrated to.
    """

    def __init__(self, state: str, rate: SkylineParameterLike, target_state: str):
        super().__init__(state, rate)
        self.target_state = target_state

    def apply(self, model: Model, time: float, rng: Generator) -> None:
        migrant = self.draw_individual(model, rng)
        model.migrate(migrant, self.target_state, time)

    def __repr__(self) -> str:
        return (
            f"Migration(state={self.state}, rate={self.rate}, "
            f"target_state={self.target_state})"
        )
61
51
 
52
class Sampling(Event):
    """Event that samples a drawn individual.

    Args:
        state: State of the individuals this event applies to.
        rate: Rate of the event.
        removal: Forwarded to ``model.sample``; presumably whether the sampled
            individual is removed from the population (TODO confirm in Model).
    """

    def __init__(self, state: str, rate: SkylineParameterLike, removal: bool):
        super().__init__(state, rate)
        self.removal = removal

    def apply(self, model: Model, time: float, rng: Generator) -> None:
        sampled = self.draw_individual(model, rng)
        model.sample(sampled, time, self.removal)

    def __repr__(self) -> str:
        return (
            f"Sampling(state={self.state}, rate={self.rate}, "
            f"removal={self.removal})"
        )
70
63
 
71
64
 
72
65
  def get_canonical_events(
73
- states: Sequence[str],
66
+ states: list[str],
74
67
  sampling_rates: SkylineVectorCoercible,
68
+ remove_after_sampling: bool,
75
69
  birth_rates: SkylineVectorCoercible = 0,
76
70
  death_rates: SkylineVectorCoercible = 0,
77
- removal_probabilities: SkylineVectorCoercible = 0,
78
71
  migration_rates: SkylineMatrixCoercible | None = None,
79
72
  birth_rates_among_states: SkylineMatrixCoercible | None = None,
80
73
  ) -> list[Event]:
@@ -83,102 +76,94 @@ def get_canonical_events(
83
76
  birth_rates = skyline_vector(birth_rates, N)
84
77
  death_rates = skyline_vector(death_rates, N)
85
78
  sampling_rates = skyline_vector(sampling_rates, N)
86
- removal_probabilities = skyline_vector(removal_probabilities, N)
87
79
 
88
80
  events: list[Event] = []
89
81
  for i, state in enumerate(states):
90
- events.append(BirthEvent(birth_rates[i], state, state))
91
- events.append(DeathEvent(death_rates[i], state))
92
- events.append(SamplingEvent(sampling_rates[i], state, removal_probabilities[i]))
82
+ events.append(Birth(state, birth_rates[i], state))
83
+ events.append(Death(state, death_rates[i]))
84
+ events.append(Sampling(state, sampling_rates[i], remove_after_sampling))
93
85
 
94
86
  if migration_rates is not None:
95
87
  migration_rates = skyline_matrix(migration_rates, N, N - 1)
96
88
  for i, state in enumerate(states):
97
89
  for j, other_state in enumerate([s for s in states if s != state]):
98
- events.append(MigrationEvent(migration_rates[i, j], state, other_state))
90
+ events.append(Migration(state, migration_rates[i, j], other_state))
99
91
 
100
92
  if birth_rates_among_states is not None:
101
93
  birth_rates_among_states = skyline_matrix(birth_rates_among_states, N, N - 1)
102
94
  for i, state in enumerate(states):
103
95
  for j, other_state in enumerate([s for s in states if s != state]):
104
- events.append(
105
- BirthEvent(birth_rates_among_states[i, j], state, other_state)
106
- )
96
+ events.append(Birth(state, birth_rates_among_states[i, j], other_state))
107
97
 
108
98
  return [event for event in events if event.rate]
109
99
 
110
100
 
111
def get_FBD_events(
    states: list[str],
    sampling_proportions: SkylineVectorCoercible = 1,
    diversification: SkylineVectorCoercible = 0,
    turnover: SkylineVectorCoercible = 0,
    migration_rates: SkylineMatrixCoercible | None = None,
    diversification_between_states: SkylineMatrixCoercible | None = None,
) -> list[Event]:
    """Build the events of an FBD (presumably fossilized birth-death) model.

    Per state the canonical rates are derived as:

    * birth rate ``= diversification / (1 - turnover)``
    * death rate ``= turnover * birth rate``
    * sampling rate ``= sampling_proportions * death rate``

    Sampled individuals are NOT removed (``remove_after_sampling=False``).
    If ``diversification_between_states`` is given, it is coerced to an
    ``N x (N - 1)`` skyline matrix (one column per *other* state) and the
    per-state death rates are added to it to form the cross-state birth rates.

    Returns:
        The list produced by ``get_canonical_events`` for these rates.
    """
    n_states = len(states)

    div = skyline_vector(diversification, n_states)
    turn = skyline_vector(turnover, n_states)
    proportions = skyline_vector(sampling_proportions, n_states)

    birth = div / (1 - turn)
    death = turn * birth
    sampling = proportions * death

    if diversification_between_states is None:
        births_among_states = None
    else:
        births_among_states = (
            skyline_matrix(diversification_between_states, n_states, n_states - 1)
            + death
        )

    return get_canonical_events(
        states=states,
        sampling_rates=sampling,
        remove_after_sampling=False,
        birth_rates=birth,
        death_rates=death,
        migration_rates=migration_rates,
        birth_rates_among_states=births_among_states,
    )
148
133
 
149
134
 
150
def get_epidemiological_events(
    states: list[str],
    sampling_proportions: SkylineVectorCoercible = 1,
    reproduction_numbers: SkylineVectorCoercible = 0,
    become_uninfectious_rates: SkylineVectorCoercible = 0,
    migration_rates: SkylineMatrixCoercible | None = None,
    reproduction_numbers_among_states: SkylineMatrixCoercible | None = None,
) -> list[Event]:
    """Build the events of a multi-state epidemiological birth-death model.

    With ``gamma = become_uninfectious_rates`` the canonical rates per state are:

    * birth rate ``= reproduction_numbers * gamma``
    * sampling rate ``= gamma * sampling_proportions``
    * death rate ``= gamma - sampling rate``

    Sampled individuals are removed (``remove_after_sampling=True``). If
    ``reproduction_numbers_among_states`` is given, it is coerced to an
    ``N x (N - 1)`` skyline matrix (one column per *other* state) and
    multiplied by ``gamma`` to form the cross-state birth rates.

    Returns:
        The list produced by ``get_canonical_events`` for these rates.
    """
    n_states = len(states)

    r0 = skyline_vector(reproduction_numbers, n_states)
    gamma = skyline_vector(become_uninfectious_rates, n_states)
    proportions = skyline_vector(sampling_proportions, n_states)

    birth = r0 * gamma
    sampling = gamma * proportions
    death = gamma - sampling

    if reproduction_numbers_among_states is None:
        births_among_states = None
    else:
        births_among_states = (
            skyline_matrix(reproduction_numbers_among_states, n_states, n_states - 1)
            * gamma
        )

    return get_canonical_events(
        states=states,
        sampling_rates=sampling,
        remove_after_sampling=True,
        birth_rates=birth,
        death_rates=death,
        migration_rates=migration_rates,
        birth_rates_among_states=births_among_states,
    )
@@ -232,40 +217,3 @@ def get_BDSS_events(
232
217
  become_uninfectious_rates=1 / infectious_period,
233
218
  sampling_proportions=sampling_proportion,
234
219
  )
235
-
236
-
237
- def get_contact_tracing_events(
238
- events: Sequence[Event],
239
- samplable_states_after_notification: Sequence[str] | None = None,
240
- sampling_rate_after_notification: SkylineParameterLike = np.inf,
241
- contacts_removal_probability: SkylineParameterLike = 1,
242
- ) -> list[Event]:
243
- ct_events = list(events)
244
- for event in events:
245
- if isinstance(event, MigrationEvent):
246
- ct_events.append(
247
- MigrationEvent(
248
- event.rate,
249
- get_CT_state(event.state),
250
- get_CT_state(event.target_state),
251
- )
252
- )
253
- elif isinstance(event, BirthEvent):
254
- ct_events.append(
255
- BirthEvent(event.rate, get_CT_state(event.state), event.child_state)
256
- )
257
-
258
- for state in (
259
- samplable_states_after_notification
260
- if samplable_states_after_notification is not None
261
- else [e.state for e in events]
262
- ):
263
- ct_events.append(
264
- SamplingEvent(
265
- skyline_parameter(sampling_rate_after_notification),
266
- get_CT_state(state),
267
- skyline_parameter(contacts_removal_probability),
268
- )
269
- )
270
-
271
- return ct_events
@@ -0,0 +1,105 @@
1
+ from copy import deepcopy
2
+ from enum import Enum
3
+ from typing import Type
4
+
5
+ from numpy.random import Generator
6
+
7
+ from phylogenie.skyline import SkylineParameterLike
8
+ from phylogenie.treesimulator.events.contact_tracing import (
9
+ BirthWithContactTracing,
10
+ SamplingWithContactTracing,
11
+ )
12
+ from phylogenie.treesimulator.events.core import (
13
+ Birth,
14
+ Death,
15
+ Event,
16
+ Migration,
17
+ Sampling,
18
+ )
19
+ from phylogenie.treesimulator.model import Model
20
+ from phylogenie.utils import Distribution
21
+
22
+ MUTATION_PREFIX = "MUT-"
23
+ MUTATIONS_KEY = "MUTATIONS"
24
+
25
+
26
+ def _get_mutation(state: str) -> str | None:
27
+ return state.split(".")[0] if state.startswith(MUTATION_PREFIX) else None
28
+
29
+
30
+ def _get_mutated_state(mutation_id: int, state: str) -> str:
31
+ if state.startswith(MUTATION_PREFIX):
32
+ state = state.split(".")[1]
33
+ return f"{MUTATION_PREFIX}{mutation_id}.{state}"
34
+
35
+
36
class TargetType(str, Enum):
    """Kinds of events whose rates a ``Mutation`` can rescale.

    The ``str`` mixin makes each member compare equal to its plain string
    value, so e.g. ``"birth"`` and ``TargetType.BIRTH`` are interchangeable
    as dictionary keys.
    """

    BIRTH = "birth"
    DEATH = "death"
    MIGRATION = "migration"
    SAMPLING = "sampling"
    MUTATION = "mutation"
42
+
43
+
44
class Mutation(Event):
    """Event that starts a new mutation lineage from a drawn individual.

    Each application:

    1. increments the mutation counter ``model.context[MUTATIONS_KEY]`` and
       uses it as the new mutation id;
    2. migrates the drawn individual to ``MUT-<id>.<base state>``;
    3. draws one scaler per entry of ``rate_scalers`` by calling the ``rng``
       method named ``rate_scaler.type`` with ``rate_scaler.args``;
    4. clones every event of the current lineage (same mutation tag as
       ``state``, or all unmutated events when ``state`` is unmutated) into the
       new lineage, multiplies the rate of each clone whose type is listed in
       ``TARGETS`` for a target type by that target's scaler, and adds it to
       the model.

    Args:
        state: State of the individuals this event applies to.
        rate: Rate of the event.
        rate_scalers: Distribution of the rate multiplier per target type.

    Raises:
        ValueError: If ``rate_scalers`` has an unsupported target type, or a
            cloned event has a type mutations are not defined for.
    """

    def __init__(
        self,
        state: str,
        rate: SkylineParameterLike,
        rate_scalers: dict[TargetType, Distribution],
    ):
        super().__init__(state, rate)
        self.rate_scalers = rate_scalers

    def apply(self, model: Model, time: float, rng: Generator) -> None:
        # Validate before touching the model: previously an invalid target type
        # was only detected after the counter was bumped and the individual
        # migrated, leaving the model half-updated.
        for target_type in self.rate_scalers:
            if target_type not in TARGETS:
                raise ValueError(
                    f"Unsupported target type {target_type} for mutation."
                )

        if MUTATIONS_KEY not in model.context:
            model.context[MUTATIONS_KEY] = 0
        model.context[MUTATIONS_KEY] += 1
        mutation_id = model.context[MUTATIONS_KEY]

        individual = self.draw_individual(model, rng)
        model.migrate(individual, _get_mutated_state(mutation_id, self.state), time)

        rate_scalers = {
            target_type: getattr(rng, rate_scaler.type)(**rate_scaler.args)
            for target_type, rate_scaler in self.rate_scalers.items()
        }

        lineage = _get_mutation(self.state)
        # Materialize the clones first so that add_event cannot affect the
        # iteration over model.events.
        for event in [
            deepcopy(e) for e in model.events if _get_mutation(e.state) == lineage
        ]:
            event.state = _get_mutated_state(mutation_id, event.state)
            if isinstance(event, Birth | BirthWithContactTracing):
                event.child_state = _get_mutated_state(mutation_id, event.child_state)
            elif isinstance(event, Migration):
                event.target_state = _get_mutated_state(mutation_id, event.target_state)
            elif not isinstance(
                event, Mutation | Death | Sampling | SamplingWithContactTracing
            ):
                raise ValueError(
                    f"Mutation not defined for event of type {type(event)}."
                )

            for target_type, rate_scaler in rate_scalers.items():
                if isinstance(event, TARGETS[target_type]):
                    event.rate *= rate_scaler

            model.add_event(event)

    def __repr__(self) -> str:
        return f"Mutation(state={self.state}, rate={self.rate}, rate_scalers={self.rate_scalers})"
97
+
98
+
99
# Event classes whose rates are rescaled for each TargetType when a
# ``Mutation`` clones the events of a lineage.
TARGETS: dict[TargetType, tuple[Type[Event], ...]] = dict(
    [
        (TargetType.BIRTH, (Birth, BirthWithContactTracing)),
        (TargetType.DEATH, (Death,)),
        (TargetType.MIGRATION, (Migration,)),
        (TargetType.SAMPLING, (Sampling, SamplingWithContactTracing)),
        (TargetType.MUTATION, (Mutation,)),
    ]
)