phylogenie 2.1.4__py3-none-any.whl → 3.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40) hide show
  1. phylogenie/__init__.py +60 -14
  2. phylogenie/draw.py +690 -0
  3. phylogenie/generators/alisim.py +12 -12
  4. phylogenie/generators/configs.py +26 -4
  5. phylogenie/generators/dataset.py +3 -3
  6. phylogenie/generators/factories.py +38 -12
  7. phylogenie/generators/trees.py +48 -47
  8. phylogenie/io/__init__.py +3 -0
  9. phylogenie/io/fasta.py +34 -0
  10. phylogenie/main.py +27 -10
  11. phylogenie/mixins.py +33 -0
  12. phylogenie/skyline/matrix.py +11 -7
  13. phylogenie/skyline/parameter.py +12 -4
  14. phylogenie/skyline/vector.py +12 -6
  15. phylogenie/treesimulator/__init__.py +36 -3
  16. phylogenie/treesimulator/events/__init__.py +5 -5
  17. phylogenie/treesimulator/events/base.py +39 -0
  18. phylogenie/treesimulator/events/contact_tracing.py +38 -23
  19. phylogenie/treesimulator/events/core.py +21 -12
  20. phylogenie/treesimulator/events/mutations.py +46 -46
  21. phylogenie/treesimulator/features.py +49 -0
  22. phylogenie/treesimulator/gillespie.py +59 -55
  23. phylogenie/treesimulator/io/__init__.py +4 -0
  24. phylogenie/treesimulator/io/newick.py +104 -0
  25. phylogenie/treesimulator/io/nexus.py +50 -0
  26. phylogenie/treesimulator/model.py +25 -49
  27. phylogenie/treesimulator/tree.py +196 -0
  28. phylogenie/treesimulator/utils.py +108 -0
  29. phylogenie/typings.py +3 -3
  30. {phylogenie-2.1.4.dist-info → phylogenie-3.1.7.dist-info}/METADATA +13 -15
  31. phylogenie-3.1.7.dist-info/RECORD +41 -0
  32. {phylogenie-2.1.4.dist-info → phylogenie-3.1.7.dist-info}/WHEEL +2 -1
  33. phylogenie-3.1.7.dist-info/entry_points.txt +2 -0
  34. phylogenie-3.1.7.dist-info/top_level.txt +1 -0
  35. phylogenie/io.py +0 -107
  36. phylogenie/tree.py +0 -92
  37. phylogenie/utils.py +0 -17
  38. phylogenie-2.1.4.dist-info/RECORD +0 -32
  39. phylogenie-2.1.4.dist-info/entry_points.txt +0 -3
  40. {phylogenie-2.1.4.dist-info → phylogenie-3.1.7.dist-info/licenses}/LICENSE.txt +0 -0
@@ -47,7 +47,7 @@ class SkylineVector:
47
47
  ):
48
48
  if params is not None and value is None and change_times is None:
49
49
  if is_many_skyline_parameters_like(params):
50
- self.params = [skyline_parameter(param) for param in params]
50
+ self._params = [skyline_parameter(param) for param in params]
51
51
  else:
52
52
  raise TypeError(
53
53
  f"It is impossible to create a SkylineVector from `params` {params} of type {type(params)}. Please provide a sequence of SkylineParameterLike objects (a SkylineParameterLike object can either be a SkylineParameter or a scalar)."
@@ -63,7 +63,7 @@ class SkylineVector:
63
63
  raise TypeError(
64
64
  f"It is impossible to create a SkylineVector from `value` {value} of type {type(value)}. Please provide a nested (2D) sequence of scalar values."
65
65
  )
66
- self.params = [
66
+ self._params = [
67
67
  SkylineParameter([vector[i] for vector in value], change_times)
68
68
  for i in range(len(value[0]))
69
69
  ]
@@ -72,20 +72,26 @@ class SkylineVector:
72
72
  "Either `params` or both `value` and `change_times` must be provided to create a SkylineVector."
73
73
  )
74
74
 
75
+ @property
76
+ def params(self) -> tuple[SkylineParameter, ...]:
77
+ return tuple(self._params)
78
+
75
79
  @property
76
80
  def change_times(self) -> pgt.Vector1D:
77
- return sorted(set(t for param in self.params for t in param.change_times))
81
+ return tuple(
82
+ sorted(set(t for param in self.params for t in param.change_times))
83
+ )
78
84
 
79
85
  @property
80
86
  def value(self) -> pgt.Vector2D:
81
- return [self.get_value_at_time(t) for t in (0, *self.change_times)]
87
+ return tuple(self.get_value_at_time(t) for t in (0, *self.change_times))
82
88
 
83
89
  @property
84
90
  def N(self) -> int:
85
91
  return len(self.params)
86
92
 
87
93
  def get_value_at_time(self, t: pgt.Scalar) -> pgt.Vector1D:
88
- return [param.get_value_at_time(t) for param in self.params]
94
+ return tuple(param.get_value_at_time(t) for param in self.params)
89
95
 
90
96
  def _operate(
91
97
  self,
@@ -154,7 +160,7 @@ class SkylineVector:
154
160
  raise TypeError(
155
161
  f"It is impossible to set item {item} of SkylineVector with value {value} of type {type(value)}. Please provide a SkylineParameterLike object (i.e., a scalar or a SkylineParameter)."
156
162
  )
157
- self.params[item] = skyline_parameter(value)
163
+ self._params[item] = skyline_parameter(value)
158
164
 
159
165
 
160
166
  def skyline_vector(x: SkylineVectorCoercible, N: int) -> SkylineVector:
@@ -3,9 +3,9 @@ from phylogenie.treesimulator.events import (
3
3
  BirthWithContactTracing,
4
4
  Death,
5
5
  Event,
6
+ EventType,
6
7
  Migration,
7
8
  Mutation,
8
- MutationTargetType,
9
9
  Sampling,
10
10
  SamplingWithContactTracing,
11
11
  get_BD_events,
@@ -15,17 +15,33 @@ from phylogenie.treesimulator.events import (
15
15
  get_contact_tracing_events,
16
16
  get_epidemiological_events,
17
17
  get_FBD_events,
18
+ get_mutation_id,
18
19
  )
20
+ from phylogenie.treesimulator.features import Feature, set_features
19
21
  from phylogenie.treesimulator.gillespie import generate_trees, simulate_tree
22
+ from phylogenie.treesimulator.io import dump_newick, load_newick, load_nexus
23
+ from phylogenie.treesimulator.model import get_node_state
24
+ from phylogenie.treesimulator.tree import Tree
25
+ from phylogenie.treesimulator.utils import (
26
+ compute_mean_leaf_pairwise_distance,
27
+ compute_sackin_index,
28
+ get_distance,
29
+ get_mrca,
30
+ get_node_depth_levels,
31
+ get_node_depths,
32
+ get_node_height_levels,
33
+ get_node_heights,
34
+ get_node_leaf_counts,
35
+ )
20
36
 
21
37
  __all__ = [
22
38
  "Birth",
23
39
  "BirthWithContactTracing",
24
40
  "Death",
25
41
  "Event",
42
+ "EventType",
26
43
  "Migration",
27
44
  "Mutation",
28
- "MutationTargetType",
29
45
  "Sampling",
30
46
  "SamplingWithContactTracing",
31
47
  "get_BD_events",
@@ -35,6 +51,23 @@ __all__ = [
35
51
  "get_contact_tracing_events",
36
52
  "get_epidemiological_events",
37
53
  "get_FBD_events",
38
- "generate_trees",
54
+ "get_mutation_id",
55
+ "Feature",
56
+ "set_features",
39
57
  "simulate_tree",
58
+ "dump_newick",
59
+ "load_newick",
60
+ "load_nexus",
61
+ "generate_trees",
62
+ "get_node_state",
63
+ "Tree",
64
+ "compute_mean_leaf_pairwise_distance",
65
+ "compute_sackin_index",
66
+ "get_distance",
67
+ "get_mrca",
68
+ "get_node_depth_levels",
69
+ "get_node_depths",
70
+ "get_node_height_levels",
71
+ "get_node_heights",
72
+ "get_node_leaf_counts",
40
73
  ]
@@ -1,3 +1,4 @@
1
+ from phylogenie.treesimulator.events.base import Event, EventType
1
2
  from phylogenie.treesimulator.events.contact_tracing import (
2
3
  BirthWithContactTracing,
3
4
  SamplingWithContactTracing,
@@ -6,7 +7,6 @@ from phylogenie.treesimulator.events.contact_tracing import (
6
7
  from phylogenie.treesimulator.events.core import (
7
8
  Birth,
8
9
  Death,
9
- Event,
10
10
  Migration,
11
11
  Sampling,
12
12
  get_BD_events,
@@ -16,19 +16,17 @@ from phylogenie.treesimulator.events.core import (
16
16
  get_epidemiological_events,
17
17
  get_FBD_events,
18
18
  )
19
- from phylogenie.treesimulator.events.mutations import Mutation
20
- from phylogenie.treesimulator.events.mutations import TargetType as MutationTargetType
19
+ from phylogenie.treesimulator.events.mutations import Mutation, get_mutation_id
21
20
 
22
21
  __all__ = [
23
22
  "Birth",
24
23
  "BirthWithContactTracing",
25
24
  "Death",
26
25
  "Event",
26
+ "EventType",
27
27
  "Migration",
28
- "Mutation",
29
28
  "Sampling",
30
29
  "SamplingWithContactTracing",
31
- "MutationTargetType",
32
30
  "get_BD_events",
33
31
  "get_BDEI_events",
34
32
  "get_BDSS_events",
@@ -36,4 +34,6 @@ __all__ = [
36
34
  "get_contact_tracing_events",
37
35
  "get_epidemiological_events",
38
36
  "get_FBD_events",
37
+ "Mutation",
38
+ "get_mutation_id",
39
39
  ]
@@ -0,0 +1,39 @@
1
+ from abc import ABC, abstractmethod
2
+ from enum import Enum
3
+ from typing import Any
4
+
5
+ from numpy.random import Generator
6
+
7
+ from phylogenie.skyline import SkylineParameterLike, skyline_parameter
8
+ from phylogenie.treesimulator.model import Model
9
+
10
+
11
+ class EventType(str, Enum):
12
+ BIRTH = "birth"
13
+ DEATH = "death"
14
+ MIGRATION = "migration"
15
+ SAMPLING = "sampling"
16
+ MUTATION = "mutation"
17
+
18
+
19
+ class Event(ABC):
20
+ type: EventType
21
+
22
+ def __init__(self, state: str, rate: SkylineParameterLike):
23
+ self.state = state
24
+ self.rate = skyline_parameter(rate)
25
+ if any(v < 0 for v in self.rate.value):
26
+ raise ValueError("Event rates must be non-negative.")
27
+
28
+ def draw_individual(self, model: Model, rng: Generator) -> int:
29
+ return rng.choice(model.get_population(self.state))
30
+
31
+ def get_propensity(self, model: Model, time: float) -> float:
32
+ n_individuals = model.count_individuals(self.state)
33
+ rate = self.rate.get_value_at_time(time)
34
+ return rate * n_individuals
35
+
36
+ @abstractmethod
37
+ def apply(
38
+ self, model: Model, events: "list[Event]", time: float, rng: Generator
39
+ ) -> dict[str, Any] | None: ...
@@ -1,43 +1,48 @@
1
1
  from collections import defaultdict
2
2
  from collections.abc import Sequence
3
+ from copy import deepcopy
3
4
 
4
- import numpy as np
5
5
  from numpy.random import Generator
6
6
 
7
7
  from phylogenie.skyline import SkylineParameterLike, skyline_parameter
8
+ from phylogenie.treesimulator.events.base import Event, EventType
8
9
  from phylogenie.treesimulator.events.core import Birth, Death, Migration, Sampling
9
- from phylogenie.treesimulator.model import Event, Model
10
+ from phylogenie.treesimulator.model import Model
10
11
 
11
12
  CT_POSTFIX = "-CT"
12
13
  CONTACTS_KEY = "CONTACTS"
13
14
 
14
15
 
15
- def _get_CT_state(state: str) -> str:
16
+ def get_CT_state(state: str) -> str:
16
17
  return f"{state}{CT_POSTFIX}"
17
18
 
18
19
 
19
- def _is_CT_state(state: str) -> bool:
20
+ def is_CT_state(state: str) -> bool:
20
21
  return state.endswith(CT_POSTFIX)
21
22
 
22
23
 
23
24
  class BirthWithContactTracing(Event):
25
+ type = EventType.BIRTH
26
+
24
27
  def __init__(self, state: str, rate: SkylineParameterLike, child_state: str):
25
28
  super().__init__(state, rate)
26
29
  self.child_state = child_state
27
30
 
28
- def apply(self, model: Model, time: float, rng: Generator) -> None:
31
+ def apply(self, model: Model, events: list[Event], time: float, rng: Generator):
29
32
  individual = self.draw_individual(model, rng)
30
33
  new_individual = model.birth_from(individual, self.child_state, time)
31
- if CONTACTS_KEY not in model.context:
32
- model.context[CONTACTS_KEY] = defaultdict(list)
33
- model.context[CONTACTS_KEY][individual].append(new_individual)
34
- model.context[CONTACTS_KEY][new_individual].append(individual)
34
+ if CONTACTS_KEY not in model.metadata:
35
+ model[CONTACTS_KEY] = defaultdict(list)
36
+ model[CONTACTS_KEY][individual].append(new_individual)
37
+ model[CONTACTS_KEY][new_individual].append(individual)
35
38
 
36
39
  def __repr__(self) -> str:
37
40
  return f"BirthWithContactTracing(state={self.state}, rate={self.rate}, child_state={self.child_state})"
38
41
 
39
42
 
40
43
  class SamplingWithContactTracing(Event):
44
+ type = EventType.SAMPLING
45
+
41
46
  def __init__(
42
47
  self,
43
48
  state: str,
@@ -49,19 +54,19 @@ class SamplingWithContactTracing(Event):
49
54
  self.max_notified_contacts = max_notified_contacts
50
55
  self.notification_probability = skyline_parameter(notification_probability)
51
56
 
52
- def apply(self, model: Model, time: float, rng: Generator) -> None:
57
+ def apply(self, model: Model, events: list[Event], time: float, rng: Generator):
53
58
  individual = self.draw_individual(model, rng)
54
59
  model.sample(individual, time, True)
55
60
  population = model.get_population()
56
- if CONTACTS_KEY not in model.context:
61
+ if CONTACTS_KEY not in model.metadata:
57
62
  return
58
- contacts = model.context[CONTACTS_KEY][individual]
63
+ contacts = model[CONTACTS_KEY][individual]
59
64
  for contact in contacts[-self.max_notified_contacts :]:
60
65
  if contact in population:
61
66
  state = model.get_state(contact)
62
67
  p = self.notification_probability.get_value_at_time(time)
63
- if not _is_CT_state(state) and rng.random() < p:
64
- model.migrate(contact, _get_CT_state(state), time)
68
+ if not is_CT_state(state) and rng.random() < p:
69
+ model.migrate(contact, get_CT_state(state), time)
65
70
 
66
71
  def __repr__(self) -> str:
67
72
  return f"SamplingWithContactTracing(state={self.state}, rate={self.rate}, max_notified_contacts={self.max_notified_contacts}, notification_probability={self.notification_probability})"
@@ -70,8 +75,8 @@ class SamplingWithContactTracing(Event):
70
75
  def get_contact_tracing_events(
71
76
  events: Sequence[Event],
72
77
  max_notified_contacts: int = 1,
73
- notification_probability: SkylineParameterLike = 1,
74
- sampling_rate_after_notification: SkylineParameterLike = np.inf,
78
+ notification_probability: SkylineParameterLike = 0.0,
79
+ sampling_rate_after_notification: SkylineParameterLike = 2**32,
75
80
  samplable_states_after_notification: list[str] | None = None,
76
81
  ) -> list[Event]:
77
82
  ct_events: list[Event] = []
@@ -79,17 +84,24 @@ def get_contact_tracing_events(
79
84
  sampling_rate_after_notification = skyline_parameter(
80
85
  sampling_rate_after_notification
81
86
  )
82
- for event in events:
83
- state, rate = event.state, event.rate
87
+ for event in [deepcopy(e) for e in events]:
84
88
  if isinstance(event, Migration):
85
89
  ct_events.append(event)
86
90
  ct_events.append(
87
- Migration(_get_CT_state(state), rate, _get_CT_state(event.target_state))
91
+ Migration(
92
+ get_CT_state(event.state),
93
+ event.rate,
94
+ get_CT_state(event.target_state),
95
+ )
88
96
  )
89
97
  elif isinstance(event, Birth):
90
- ct_events.append(BirthWithContactTracing(state, rate, event.child_state))
91
98
  ct_events.append(
92
- BirthWithContactTracing(_get_CT_state(state), rate, event.child_state)
99
+ BirthWithContactTracing(event.state, event.rate, event.child_state)
100
+ )
101
+ ct_events.append(
102
+ BirthWithContactTracing(
103
+ get_CT_state(event.state), event.rate, event.child_state
104
+ )
93
105
  )
94
106
  elif isinstance(event, Sampling):
95
107
  if not event.removal:
@@ -98,7 +110,10 @@ def get_contact_tracing_events(
98
110
  )
99
111
  ct_events.append(
100
112
  SamplingWithContactTracing(
101
- state, rate, max_notified_contacts, notification_probability
113
+ event.state,
114
+ event.rate,
115
+ max_notified_contacts,
116
+ notification_probability,
102
117
  )
103
118
  )
104
119
  elif isinstance(event, Death):
@@ -115,7 +130,7 @@ def get_contact_tracing_events(
115
130
  ):
116
131
  ct_events.append(
117
132
  SamplingWithContactTracing(
118
- _get_CT_state(state),
133
+ get_CT_state(state),
119
134
  sampling_rate_after_notification,
120
135
  max_notified_contacts,
121
136
  notification_probability,
@@ -7,7 +7,8 @@ from phylogenie.skyline import (
7
7
  skyline_matrix,
8
8
  skyline_vector,
9
9
  )
10
- from phylogenie.treesimulator.model import Event, Model
10
+ from phylogenie.treesimulator.events.base import Event, EventType
11
+ from phylogenie.treesimulator.model import Model
11
12
 
12
13
  INFECTIOUS_STATE = "I"
13
14
  EXPOSED_STATE = "E"
@@ -15,11 +16,13 @@ SUPERSPREADER_STATE = "S"
15
16
 
16
17
 
17
18
  class Birth(Event):
19
+ type = EventType.BIRTH
20
+
18
21
  def __init__(self, state: str, rate: SkylineParameterLike, child_state: str):
19
22
  super().__init__(state, rate)
20
23
  self.child_state = child_state
21
24
 
22
- def apply(self, model: Model, time: float, rng: Generator) -> None:
25
+ def apply(self, model: Model, events: list[Event], time: float, rng: Generator):
23
26
  individual = self.draw_individual(model, rng)
24
27
  model.birth_from(individual, self.child_state, time)
25
28
 
@@ -28,7 +31,9 @@ class Birth(Event):
28
31
 
29
32
 
30
33
  class Death(Event):
31
- def apply(self, model: Model, time: float, rng: Generator) -> None:
34
+ type = EventType.DEATH
35
+
36
+ def apply(self, model: Model, events: list[Event], time: float, rng: Generator):
32
37
  individual = self.draw_individual(model, rng)
33
38
  model.remove(individual, time)
34
39
 
@@ -37,11 +42,13 @@ class Death(Event):
37
42
 
38
43
 
39
44
  class Migration(Event):
45
+ type = EventType.MIGRATION
46
+
40
47
  def __init__(self, state: str, rate: SkylineParameterLike, target_state: str):
41
48
  super().__init__(state, rate)
42
49
  self.target_state = target_state
43
50
 
44
- def apply(self, model: Model, time: float, rng: Generator) -> None:
51
+ def apply(self, model: Model, events: list[Event], time: float, rng: Generator):
45
52
  individual = self.draw_individual(model, rng)
46
53
  model.migrate(individual, self.target_state, time)
47
54
 
@@ -50,11 +57,13 @@ class Migration(Event):
50
57
 
51
58
 
52
59
  class Sampling(Event):
60
+ type = EventType.SAMPLING
61
+
53
62
  def __init__(self, state: str, rate: SkylineParameterLike, removal: bool):
54
63
  super().__init__(state, rate)
55
64
  self.removal = removal
56
65
 
57
- def apply(self, model: Model, time: float, rng: Generator) -> None:
66
+ def apply(self, model: Model, events: list[Event], time: float, rng: Generator):
58
67
  individual = self.draw_individual(model, rng)
59
68
  model.sample(individual, time, self.removal)
60
69
 
@@ -64,8 +73,8 @@ class Sampling(Event):
64
73
 
65
74
  def get_canonical_events(
66
75
  states: list[str],
67
- sampling_rates: SkylineVectorCoercible,
68
- remove_after_sampling: bool,
76
+ sampling_rates: SkylineVectorCoercible = 0,
77
+ remove_after_sampling: bool = False,
69
78
  birth_rates: SkylineVectorCoercible = 0,
70
79
  death_rates: SkylineVectorCoercible = 0,
71
80
  migration_rates: SkylineMatrixCoercible | None = None,
@@ -100,7 +109,7 @@ def get_canonical_events(
100
109
 
101
110
  def get_FBD_events(
102
111
  states: list[str],
103
- sampling_proportions: SkylineVectorCoercible = 1,
112
+ sampling_proportions: SkylineVectorCoercible = 0,
104
113
  diversification: SkylineVectorCoercible = 0,
105
114
  turnover: SkylineVectorCoercible = 0,
106
115
  migration_rates: SkylineMatrixCoercible | None = None,
@@ -134,7 +143,7 @@ def get_FBD_events(
134
143
 
135
144
  def get_epidemiological_events(
136
145
  states: list[str],
137
- sampling_proportions: SkylineVectorCoercible = 1,
146
+ sampling_proportions: SkylineVectorCoercible,
138
147
  reproduction_numbers: SkylineVectorCoercible = 0,
139
148
  become_uninfectious_rates: SkylineVectorCoercible = 0,
140
149
  migration_rates: SkylineMatrixCoercible | None = None,
@@ -172,7 +181,7 @@ def get_epidemiological_events(
172
181
  def get_BD_events(
173
182
  reproduction_number: SkylineParameterLike,
174
183
  infectious_period: SkylineParameterLike,
175
- sampling_proportion: SkylineParameterLike = 1,
184
+ sampling_proportion: SkylineParameterLike,
176
185
  ) -> list[Event]:
177
186
  return get_epidemiological_events(
178
187
  states=[INFECTIOUS_STATE],
@@ -186,7 +195,7 @@ def get_BDEI_events(
186
195
  reproduction_number: SkylineParameterLike,
187
196
  infectious_period: SkylineParameterLike,
188
197
  incubation_period: SkylineParameterLike,
189
- sampling_proportion: SkylineParameterLike = 1,
198
+ sampling_proportion: SkylineParameterLike,
190
199
  ) -> list[Event]:
191
200
  return get_epidemiological_events(
192
201
  states=[EXPOSED_STATE, INFECTIOUS_STATE],
@@ -202,7 +211,7 @@ def get_BDSS_events(
202
211
  infectious_period: SkylineParameterLike,
203
212
  superspreading_ratio: SkylineParameterLike,
204
213
  superspreaders_proportion: SkylineParameterLike,
205
- sampling_proportion: SkylineParameterLike = 1,
214
+ sampling_proportion: SkylineParameterLike,
206
215
  ) -> list[Event]:
207
216
  f_SS = superspreaders_proportion
208
217
  r_SS = superspreading_ratio
@@ -1,26 +1,20 @@
1
+ import re
1
2
  from copy import deepcopy
2
- from enum import Enum
3
- from typing import Type
3
+ from typing import Any, Callable
4
4
 
5
5
  from numpy.random import Generator
6
6
 
7
7
  from phylogenie.skyline import SkylineParameterLike
8
+ from phylogenie.treesimulator.events.base import Event, EventType
8
9
  from phylogenie.treesimulator.events.contact_tracing import (
9
10
  BirthWithContactTracing,
10
11
  SamplingWithContactTracing,
11
12
  )
12
- from phylogenie.treesimulator.events.core import (
13
- Birth,
14
- Death,
15
- Event,
16
- Migration,
17
- Sampling,
18
- )
13
+ from phylogenie.treesimulator.events.core import Birth, Death, Migration, Sampling
19
14
  from phylogenie.treesimulator.model import Model
20
- from phylogenie.utils import Distribution
21
15
 
22
16
  MUTATION_PREFIX = "MUT-"
23
- MUTATIONS_KEY = "MUTATIONS"
17
+ NEXT_MUTATION_ID = "NEXT_MUTATION_ID"
24
18
 
25
19
 
26
20
  def _get_mutation(state: str) -> str | None:
@@ -29,77 +23,83 @@ def _get_mutation(state: str) -> str | None:
29
23
 
30
24
  def _get_mutated_state(mutation_id: int, state: str) -> str:
31
25
  if state.startswith(MUTATION_PREFIX):
32
- state = state.split(".")[1]
26
+ _, state = state.split(".")
33
27
  return f"{MUTATION_PREFIX}{mutation_id}.{state}"
34
28
 
35
29
 
36
- class TargetType(str, Enum):
37
- BIRTH = "birth"
38
- DEATH = "death"
39
- MIGRATION = "migration"
40
- SAMPLING = "sampling"
41
- MUTATION = "mutation"
30
+ def get_mutation_id(node_name: str) -> int:
31
+ match = re.search(rf"{MUTATION_PREFIX}(\d+)\.", node_name)
32
+ if match:
33
+ return int(match.group(1))
34
+ return 0
42
35
 
43
36
 
44
37
  class Mutation(Event):
38
+ type = EventType.MUTATION
39
+
45
40
  def __init__(
46
41
  self,
47
42
  state: str,
48
43
  rate: SkylineParameterLike,
49
- rate_scalers: dict[TargetType, Distribution],
44
+ rate_scalers: dict[EventType, Callable[[], float]],
45
+ rates_to_log: list[EventType] | None = None,
50
46
  ):
51
47
  super().__init__(state, rate)
52
48
  self.rate_scalers = rate_scalers
49
+ self.rates_to_log = [] if rates_to_log is None else rates_to_log
53
50
 
54
- def apply(self, model: Model, time: float, rng: Generator) -> None:
55
- if MUTATIONS_KEY not in model.context:
56
- model.context[MUTATIONS_KEY] = 0
57
- model.context[MUTATIONS_KEY] += 1
58
- mutation_id = model.context[MUTATIONS_KEY]
51
+ def apply(
52
+ self, model: Model, events: list[Event], time: float, rng: Generator
53
+ ) -> dict[str, Any]:
54
+ if NEXT_MUTATION_ID not in model.metadata:
55
+ model[NEXT_MUTATION_ID] = 0
56
+ model[NEXT_MUTATION_ID] += 1
57
+ mutation_id = model[NEXT_MUTATION_ID]
59
58
 
60
59
  individual = self.draw_individual(model, rng)
61
60
  model.migrate(individual, _get_mutated_state(mutation_id, self.state), time)
62
61
 
63
- rate_scalers = {
64
- target_type: getattr(rng, rate_scaler.type)(**rate_scaler.args)
62
+ rate_scalers: dict[EventType, float] = {
63
+ target_type: rate_scaler()
65
64
  for target_type, rate_scaler in self.rate_scalers.items()
66
65
  }
67
66
 
67
+ metadata: dict[str, Any] = {}
68
68
  for event in [
69
69
  deepcopy(e)
70
- for e in model.events
70
+ for e in events
71
71
  if _get_mutation(self.state) == _get_mutation(e.state)
72
72
  ]:
73
73
  event.state = _get_mutated_state(mutation_id, event.state)
74
+
74
75
  if isinstance(event, Birth | BirthWithContactTracing):
75
76
  event.child_state = _get_mutated_state(mutation_id, event.child_state)
77
+ metadata_key = f"birth_rate_from_{event.state}_to_{event.child_state}"
76
78
  elif isinstance(event, Migration):
77
79
  event.target_state = _get_mutated_state(mutation_id, event.target_state)
78
- elif not isinstance(
80
+ metadata_key = (
81
+ f"migration_rate_from_{event.state}_to_{event.target_state}"
82
+ )
83
+ elif isinstance(
79
84
  event, Mutation | Death | Sampling | SamplingWithContactTracing
80
85
  ):
86
+ metadata_key = f"{event.type}_rate_for_{event.state}"
87
+ else:
81
88
  raise ValueError(
82
- f"Mutation not defined for event of type {type(event)}."
89
+ f"Mutation not implemented for event of type {type(event)}."
83
90
  )
84
91
 
85
- for target_type, rate_scaler in rate_scalers.items():
86
- if target_type not in TARGETS:
87
- raise ValueError(
88
- f"Unsupported target type {target_type} for mutation."
89
- )
90
- if isinstance(event, TARGETS[target_type]):
91
- event.rate *= rate_scaler
92
+ event.rate *= rate_scalers.get(event.type, 1)
93
+ if event.type in self.rates_to_log:
94
+ metadata[metadata_key] = (
95
+ event.rate.value[0]
96
+ if len(event.rate.value) == 1
97
+ else list(event.rate.value)
98
+ )
99
+
100
+ events.append(event)
92
101
 
93
- model.add_event(event)
102
+ return metadata
94
103
 
95
104
  def __repr__(self) -> str:
96
105
  return f"Mutation(state={self.state}, rate={self.rate}, rate_scalers={self.rate_scalers})"
97
-
98
-
99
- TARGETS: dict[TargetType, tuple[Type[Event], ...]] = {
100
- TargetType.BIRTH: (Birth, BirthWithContactTracing),
101
- TargetType.DEATH: (Death,),
102
- TargetType.MIGRATION: (Migration,),
103
- TargetType.SAMPLING: (Sampling, SamplingWithContactTracing),
104
- TargetType.MUTATION: (Mutation,),
105
- }
@@ -0,0 +1,49 @@
1
+ from collections.abc import Iterable
2
+ from enum import Enum
3
+
4
+ from phylogenie.treesimulator.events.mutations import get_mutation_id
5
+ from phylogenie.treesimulator.model import get_node_state
6
+ from phylogenie.treesimulator.tree import Tree
7
+ from phylogenie.treesimulator.utils import (
8
+ get_node_depth_levels,
9
+ get_node_depths,
10
+ get_node_height_levels,
11
+ get_node_heights,
12
+ get_node_leaf_counts,
13
+ )
14
+
15
+
16
+ def _get_states(tree: Tree) -> dict[Tree, str]:
17
+ return {node: get_node_state(node.name) for node in tree}
18
+
19
+
20
+ def _get_mutations(tree: Tree) -> dict[Tree, int]:
21
+ return {node: get_mutation_id(node.name) for node in tree}
22
+
23
+
24
+ class Feature(str, Enum):
25
+ DEPTH = "depth"
26
+ DEPTH_LEVEL = "depth_level"
27
+ HEIGHT = "height"
28
+ HEIGHT_LEVEL = "height_level"
29
+ MUTATION = "mutation"
30
+ N_LEAVES = "n_leaves"
31
+ STATE = "state"
32
+
33
+
34
+ FEATURES_EXTRACTORS = {
35
+ Feature.DEPTH: get_node_depths,
36
+ Feature.DEPTH_LEVEL: get_node_depth_levels,
37
+ Feature.HEIGHT: get_node_heights,
38
+ Feature.HEIGHT_LEVEL: get_node_height_levels,
39
+ Feature.MUTATION: _get_mutations,
40
+ Feature.N_LEAVES: get_node_leaf_counts,
41
+ Feature.STATE: _get_states,
42
+ }
43
+
44
+
45
+ def set_features(tree: Tree, features: Iterable[Feature]) -> None:
46
+ for feature in features:
47
+ feature_maps = FEATURES_EXTRACTORS[feature](tree)
48
+ for node in tree:
49
+ node[feature.value] = feature_maps[node]