phylogenie 2.1.0__py3-none-any.whl → 2.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of phylogenie might be problematic. Click here for more details.

phylogenie/__init__.py CHANGED
@@ -10,7 +10,7 @@ from phylogenie.generators import (
10
10
  FBDTreeDatasetGenerator,
11
11
  TreeDatasetGeneratorConfig,
12
12
  )
13
- from phylogenie.io import load_fasta, load_newick
13
+ from phylogenie.io import dump_newick, load_fasta, load_newick
14
14
  from phylogenie.msa import MSA
15
15
  from phylogenie.skyline import (
16
16
  SkylineMatrix,
@@ -26,12 +26,21 @@ from phylogenie.skyline import (
26
26
  )
27
27
  from phylogenie.tree import Tree
28
28
  from phylogenie.treesimulator import (
29
+ Birth,
30
+ BirthWithContactTracing,
31
+ Death,
29
32
  Event,
33
+ Migration,
34
+ Mutation,
35
+ MutationTargetType,
36
+ Sampling,
37
+ SamplingWithContactTracing,
30
38
  generate_trees,
31
39
  get_BD_events,
32
40
  get_BDEI_events,
33
41
  get_BDSS_events,
34
42
  get_canonical_events,
43
+ get_contact_tracing_events,
35
44
  get_epidemiological_events,
36
45
  get_FBD_events,
37
46
  simulate_tree,
@@ -59,15 +68,25 @@ __all__ = [
59
68
  "skyline_vector",
60
69
  "Tree",
61
70
  "TreeDatasetGeneratorConfig",
71
+ "Birth",
72
+ "BirthWithContactTracing",
73
+ "Death",
62
74
  "Event",
75
+ "Migration",
76
+ "Mutation",
77
+ "MutationTargetType",
78
+ "Sampling",
79
+ "SamplingWithContactTracing",
63
80
  "get_BD_events",
64
81
  "get_BDEI_events",
65
82
  "get_BDSS_events",
66
83
  "get_canonical_events",
84
+ "get_contact_tracing_events",
67
85
  "get_epidemiological_events",
68
86
  "get_FBD_events",
69
87
  "generate_trees",
70
88
  "simulate_tree",
89
+ "dump_newick",
71
90
  "load_fasta",
72
91
  "load_newick",
73
92
  "MSA",
phylogenie/generators/configs.py CHANGED
@@ -1,12 +1,6 @@
1
- from pydantic import BaseModel, ConfigDict
2
-
3
1
  import phylogenie.typings as pgt
4
-
5
-
6
- class Distribution(BaseModel):
7
- type: str
8
- model_config = ConfigDict(extra="allow")
9
-
2
+ from phylogenie.treesimulator import MutationTargetType
3
+ from phylogenie.utils import Distribution, StrictBaseModel
10
4
 
11
5
  Integer = str | int
12
6
  Scalar = str | pgt.Scalar
@@ -15,10 +9,6 @@ OneOrManyScalars = Scalar | pgt.Many[Scalar]
15
9
  OneOrMany2DScalars = Scalar | pgt.Many2D[Scalar]
16
10
 
17
11
 
18
- class StrictBaseModel(BaseModel):
19
- model_config = ConfigDict(extra="forbid")
20
-
21
-
22
12
  class SkylineParameterModel(StrictBaseModel):
23
13
  value: ManyScalars
24
14
  change_times: ManyScalars
@@ -37,3 +27,12 @@ class SkylineMatrixModel(StrictBaseModel):
37
27
  SkylineParameter = Scalar | SkylineParameterModel
38
28
  SkylineVector = str | pgt.Scalar | pgt.Many[SkylineParameter] | SkylineVectorModel
39
29
  SkylineMatrix = str | pgt.Scalar | pgt.Many[SkylineVector] | SkylineMatrixModel | None
30
+
31
+
32
+ class Event(StrictBaseModel):
33
+ states: str | list[str] | None = None
34
+ rate: SkylineParameter
35
+
36
+
37
+ class Mutation(Event):
38
+ rate_scalers: dict[MutationTargetType, Distribution]
phylogenie/generators/dataset.py CHANGED
@@ -10,8 +10,8 @@ import pandas as pd
10
10
  from numpy.random import Generator, default_rng
11
11
  from tqdm import tqdm
12
12
 
13
- import phylogenie.generators.configs as cfg
14
- from phylogenie.generators.factories import eval_expression
13
+ from phylogenie.generators.factories import distribution
14
+ from phylogenie.utils import Distribution, StrictBaseModel
15
15
 
16
16
 
17
17
  class DataType(str, Enum):
@@ -23,12 +23,12 @@ DATA_DIRNAME = "data"
23
23
  METADATA_FILENAME = "metadata.csv"
24
24
 
25
25
 
26
- class DatasetGenerator(ABC, cfg.StrictBaseModel):
26
+ class DatasetGenerator(ABC, StrictBaseModel):
27
27
  output_dir: str = "phylogenie-outputs"
28
28
  n_samples: int | dict[str, int] = 1
29
29
  n_jobs: int = -1
30
30
  seed: int | None = None
31
- context: dict[str, cfg.Distribution] | None = None
31
+ context: dict[str, Distribution] | None = None
32
32
 
33
33
  @abstractmethod
34
34
  def _generate_one(
@@ -56,11 +56,8 @@ class DatasetGenerator(ABC, cfg.StrictBaseModel):
56
56
  data: list[dict[str, Any]] = [{} for _ in range(n_samples)]
57
57
  if self.context is not None:
58
58
  for d, (k, v) in product(data, self.context.items()):
59
- args = v.model_extra if v.model_extra is not None else {}
60
- for arg_name, arg_value in args.items():
61
- if isinstance(arg_value, str):
62
- args[arg_name] = eval_expression(arg_value, d)
63
- d[k] = np.array(getattr(rng, v.type)(**args)).tolist()
59
+ dist = distribution(v, d)
60
+ d[k] = np.array(getattr(rng, dist.type)(**dist.args)).tolist()
64
61
  df = pd.DataFrame([{"file_id": str(i), **d} for i, d in enumerate(data)])
65
62
  df.to_csv(os.path.join(output_dir, METADATA_FILENAME), index=False)
66
63
 
phylogenie/generators/factories.py CHANGED
@@ -14,9 +14,10 @@ from phylogenie.skyline import (
14
14
  SkylineVector,
15
15
  SkylineVectorCoercible,
16
16
  )
17
+ from phylogenie.utils import Distribution
17
18
 
18
19
 
19
- def eval_expression(expression: str, data: dict[str, Any]) -> Any:
20
+ def _eval_expression(expression: str, data: dict[str, Any]) -> Any:
20
21
  return np.array(
21
22
  eval(
22
23
  expression,
@@ -29,9 +30,17 @@ def eval_expression(expression: str, data: dict[str, Any]) -> Any:
29
30
  ).tolist()
30
31
 
31
32
 
33
+ def distribution(x: Distribution, data: dict[str, Any]) -> Distribution:
34
+ args = x.args
35
+ for arg_name, arg_value in args.items():
36
+ if isinstance(arg_value, str):
37
+ args[arg_name] = _eval_expression(arg_value, data)
38
+ return Distribution(type=x.type, **args)
39
+
40
+
32
41
  def integer(x: cfg.Integer, data: dict[str, Any]) -> int:
33
42
  if isinstance(x, str):
34
- e = eval_expression(x, data)
43
+ e = _eval_expression(x, data)
35
44
  if isinstance(e, int):
36
45
  return e
37
46
  raise ValueError(
@@ -42,7 +51,7 @@ def integer(x: cfg.Integer, data: dict[str, Any]) -> int:
42
51
 
43
52
  def scalar(x: cfg.Scalar, data: dict[str, Any]) -> pgt.Scalar:
44
53
  if isinstance(x, str):
45
- e = eval_expression(x, data)
54
+ e = _eval_expression(x, data)
46
55
  if isinstance(e, pgt.Scalar):
47
56
  return e
48
57
  raise ValueError(
@@ -53,7 +62,7 @@ def scalar(x: cfg.Scalar, data: dict[str, Any]) -> pgt.Scalar:
53
62
 
54
63
  def many_scalars(x: cfg.ManyScalars, data: dict[str, Any]) -> pgt.ManyScalars:
55
64
  if isinstance(x, str):
56
- e = eval_expression(x, data)
65
+ e = _eval_expression(x, data)
57
66
  if tg.is_many_scalars(e):
58
67
  return e
59
68
  raise ValueError(
@@ -66,7 +75,7 @@ def one_or_many_scalars(
66
75
  x: cfg.OneOrManyScalars, data: dict[str, Any]
67
76
  ) -> pgt.OneOrManyScalars:
68
77
  if isinstance(x, str):
69
- e = eval_expression(x, data)
78
+ e = _eval_expression(x, data)
70
79
  if tg.is_one_or_many_scalars(e):
71
80
  return e
72
81
  raise ValueError(
@@ -92,7 +101,7 @@ def skyline_vector(
92
101
  x: cfg.SkylineVector, data: dict[str, Any]
93
102
  ) -> SkylineVectorCoercible:
94
103
  if isinstance(x, str):
95
- e = eval_expression(x, data)
104
+ e = _eval_expression(x, data)
96
105
  if tg.is_one_or_many_scalars(e):
97
106
  return e
98
107
  raise ValueError(
@@ -107,7 +116,7 @@ def skyline_vector(
107
116
 
108
117
  change_times = many_scalars(x.change_times, data)
109
118
  if isinstance(x.value, str):
110
- e = eval_expression(x.value, data)
119
+ e = _eval_expression(x.value, data)
111
120
  if tg.is_many_one_or_many_scalars(e):
112
121
  value = e
113
122
  else:
@@ -135,7 +144,7 @@ def one_or_many_2D_scalars(
135
144
  x: cfg.OneOrMany2DScalars, data: dict[str, Any]
136
145
  ) -> pgt.OneOrMany2DScalars:
137
146
  if isinstance(x, str):
138
- e = eval_expression(x, data)
147
+ e = _eval_expression(x, data)
139
148
  if tg.is_one_or_many_2D_scalars(e):
140
149
  return e
141
150
  raise ValueError(
@@ -153,7 +162,7 @@ def skyline_matrix(
153
162
  return None
154
163
 
155
164
  if isinstance(x, str):
156
- e = eval_expression(x, data)
165
+ e = _eval_expression(x, data)
157
166
  if tg.is_one_or_many_2D_scalars(e):
158
167
  return e
159
168
  raise ValueError(
@@ -168,7 +177,7 @@ def skyline_matrix(
168
177
 
169
178
  change_times = many_scalars(x.change_times, data)
170
179
  if isinstance(x.value, str):
171
- e = eval_expression(x.value, data)
180
+ e = _eval_expression(x.value, data)
172
181
  if tg.is_many_one_or_many_2D_scalars(e):
173
182
  value = e
174
183
  else:
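[file header missing from this extract — presumably a module under phylogenie/generators/ that defines the tree dataset generators] CHANGED
@@ -9,6 +9,7 @@ from pydantic import Field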
9
9
  import phylogenie.generators.configs as cfg
10
10
  from phylogenie.generators.dataset import DatasetGenerator, DataType
11
11
  from phylogenie.generators.factories import (
12
+ distribution,
12
13
  integer,
13
14
  scalar,
14
15
  skyline_matrix,
@@ -19,6 +20,7 @@ from phylogenie.io import dump_newick
19
20
  from phylogenie.tree import Tree
20
21
  from phylogenie.treesimulator import (
21
22
  Event,
23
+ Mutation,
22
24
  get_BD_events,
23
25
  get_BDEI_events,
24
26
  get_BDSS_events,
@@ -46,7 +48,6 @@ class TreeDatasetGenerator(DatasetGenerator):
46
48
  max_time: cfg.Scalar = np.inf
47
49
  init_state: str | None = None
48
50
  sampling_probability_at_present: cfg.Scalar = 0.0
49
- max_tries: int | None = None
50
51
 
51
52
  @abstractmethod
52
53
  def _get_events(self, data: dict[str, Any]) -> list[Event]: ...
@@ -67,7 +68,6 @@ class TreeDatasetGenerator(DatasetGenerator):
67
68
  sampling_probability_at_present=scalar(
68
69
  self.sampling_probability_at_present, data
69
70
  ),
70
- max_tries=self.max_tries,
71
71
  seed=int(rng.integers(2**32)),
72
72
  )
73
73
 
@@ -127,11 +127,12 @@ class FBDTreeDatasetGenerator(TreeDatasetGenerator):
127
127
  )
128
128
 
129
129
 
130
- class TreeDatasetGeneratorWithContactTracing(TreeDatasetGenerator):
130
+ class TreeDatasetGeneratorForEpidemiology(TreeDatasetGenerator):
131
131
  max_notified_contacts: cfg.Integer = 1
132
132
  notification_probability: cfg.SkylineParameter = 0.0
133
133
  sampling_rate_after_notification: cfg.SkylineParameter = np.inf
134
134
  samplable_states_after_notification: list[str] | None = None
135
+ mutations: tuple[cfg.Mutation, ...] = Field(default_factory=tuple)
135
136
 
136
137
  @abstractmethod
137
138
  def _get_base_events(self, data: dict[str, Any]) -> list[Event]: ...
@@ -150,10 +151,30 @@ class TreeDatasetGeneratorWithContactTracing(TreeDatasetGenerator):
150
151
  ),
151
152
  samplable_states_after_notification=self.samplable_states_after_notification,
152
153
  )
154
+ all_states = list({e.state for e in events})
155
+ for mutation in self.mutations:
156
+ states = mutation.states
157
+ if isinstance(states, str):
158
+ states = [states]
159
+ elif states is None:
160
+ states = all_states
161
+ for state in states:
162
+ if state not in all_states:
163
+ raise ValueError(
164
+ f"Mutation state '{state}' is not found in states {all_states}."
165
+ )
166
+ rate_scalers = {
167
+ t: distribution(r, data) for t, r in mutation.rate_scalers.items()
168
+ }
169
+ events.append(
170
+ Mutation(
171
+ state, skyline_parameter(mutation.rate, data), rate_scalers
172
+ )
173
+ )
153
174
  return events
154
175
 
155
176
 
156
- class EpidemiologicalTreeDatasetGenerator(TreeDatasetGeneratorWithContactTracing):
177
+ class EpidemiologicalTreeDatasetGenerator(TreeDatasetGeneratorForEpidemiology):
157
178
  parameterization: Literal[ParameterizationType.EPIDEMIOLOGICAL] = (
158
179
  ParameterizationType.EPIDEMIOLOGICAL
159
180
  )
@@ -179,7 +200,7 @@ class EpidemiologicalTreeDatasetGenerator(TreeDatasetGeneratorWithContactTracing
179
200
  )
180
201
 
181
202
 
182
- class BDTreeDatasetGenerator(TreeDatasetGeneratorWithContactTracing):
203
+ class BDTreeDatasetGenerator(TreeDatasetGeneratorForEpidemiology):
183
204
  parameterization: Literal[ParameterizationType.BD] = ParameterizationType.BD
184
205
  reproduction_number: cfg.SkylineParameter
185
206
  infectious_period: cfg.SkylineParameter
@@ -193,7 +214,7 @@ class BDTreeDatasetGenerator(TreeDatasetGeneratorWithContactTracing):
193
214
  )
194
215
 
195
216
 
196
- class BDEITreeDatasetGenerator(TreeDatasetGeneratorWithContactTracing):
217
+ class BDEITreeDatasetGenerator(TreeDatasetGeneratorForEpidemiology):
197
218
  parameterization: Literal[ParameterizationType.BDEI] = ParameterizationType.BDEI
198
219
  reproduction_number: cfg.SkylineParameter
199
220
  infectious_period: cfg.SkylineParameter
@@ -209,7 +230,7 @@ class BDEITreeDatasetGenerator(TreeDatasetGeneratorWithContactTracing):
209
230
  )
210
231
 
211
232
 
212
- class BDSSTreeDatasetGenerator(TreeDatasetGeneratorWithContactTracing):
233
+ class BDSSTreeDatasetGenerator(TreeDatasetGeneratorForEpidemiology):
213
234
  parameterization: Literal[ParameterizationType.BDSS] = ParameterizationType.BDSS
214
235
  reproduction_number: cfg.SkylineParameter
215
236
  infectious_period: cfg.SkylineParameter
phylogenie/io.py CHANGED
@@ -24,21 +24,28 @@ def _parse_newick(newick: str) -> Tree:
24
24
  stack.append(current_nodes)
25
25
  current_nodes = []
26
26
  else:
27
- id = _parse_chars([":", ",", ")", ";"])
27
+ id = _parse_chars([":", ",", ")", ";", "["])
28
28
  branch_length = None
29
29
  if newick[i] == ":":
30
30
  i += 1
31
- branch_length = _parse_chars([",", ")", ";"])
31
+ branch_length = float(_parse_chars([",", ")", ";", "["]))
32
32
 
33
- current_node = Tree(
34
- id=id,
35
- branch_length=(None if branch_length is None else float(branch_length)),
36
- )
33
+ current_node = Tree(id, branch_length)
37
34
  for node in current_children:
38
35
  current_node.add_child(node)
39
36
  current_children = []
40
37
  current_nodes.append(current_node)
41
38
 
39
+ if newick[i] == "[":
40
+ i += 1
41
+ features = _parse_chars(["]"]).split(":")
42
+ i += 1
43
+ if features[0] != "&&NHX":
44
+ raise ValueError(f"Expected '&&NHX' for node features.")
45
+ for feature in features[1:]:
46
+ key, value = feature.split("=", 1)
47
+ current_node.set(key, eval(value))
48
+
42
49
  if newick[i] == ")":
43
50
  current_children = current_nodes
44
51
  current_nodes = stack.pop()
@@ -47,7 +54,7 @@ def _parse_newick(newick: str) -> Tree:
47
54
 
48
55
  i += 1
49
56
 
50
- raise ValueError("Newick string does not end with a semicolon.")
57
+ raise ValueError("Newick string is invalid.")
51
58
 
52
59
 
53
60
  def load_newick(filepath: str) -> Tree | list[Tree]:
@@ -63,6 +70,10 @@ def _to_newick(tree: Tree) -> str:
63
70
  newick = f"({children_newick}){newick}"
64
71
  if tree.branch_length is not None:
65
72
  newick += f":{tree.branch_length}"
73
+ if tree.features:
74
+ reprs = {k: repr(v).replace("'", '"') for k, v in tree.features.items()}
75
+ features = [f"{k}={repr}" for k, repr in reprs.items()]
76
+ newick += f"[&&NHX:{':'.join(features)}]"
66
77
  return newick
67
78
 
68
79
 
@@ -83,13 +94,14 @@ def load_fasta(
83
94
  if not line.startswith(">"):
84
95
  raise ValueError(f"Invalid FASTA format: expected '>', got '{line[0]}'")
85
96
  id = line[1:].strip()
97
+ time = None
86
98
  if extract_time_from_id is not None:
87
99
  time = extract_time_from_id(id)
88
- else:
100
+ elif "|" in id:
89
101
  try:
90
102
  time = float(id.split("|")[-1])
91
- except:
92
- time = None
103
+ except ValueError:
104
+ pass
93
105
  chars = next(f).strip()
94
106
  sequences.append(Sequence(id, chars, time))
95
107
  return MSA(sequences)
phylogenie/tree.py CHANGED
@@ -1,18 +1,41 @@
1
1
  from collections.abc import Iterator
2
+ from typing import Any
2
3
 
3
4
 
4
5
  class Tree:
5
6
  def __init__(self, id: str = "", branch_length: float | None = None):
6
7
  self.id = id
7
8
  self.branch_length = branch_length
8
- self.parent: Tree | None = None
9
- self.children: list[Tree] = []
9
+ self._parent: Tree | None = None
10
+ self._children: list[Tree] = []
11
+ self._features: dict[str, Any] = {}
12
+
13
+ @property
14
+ def children(self) -> tuple["Tree", ...]:
15
+ return tuple(self._children)
16
+
17
+ @property
18
+ def parent(self) -> "Tree | None":
19
+ return self._parent
20
+
21
+ @property
22
+ def features(self) -> dict[str, Any]:
23
+ return self._features
10
24
 
11
25
  def add_child(self, child: "Tree") -> "Tree":
12
- child.parent = self
13
- self.children.append(child)
26
+ child._parent = self
27
+ self._children.append(child)
14
28
  return self
15
29
 
30
+ def remove_child(self, child: "Tree") -> None:
31
+ self._children.remove(child)
32
+ child._parent = None
33
+
34
+ def set_parent(self, node: "Tree | None"):
35
+ self._parent = node
36
+ if node is not None:
37
+ node._children.append(self)
38
+
16
39
  def preorder_traversal(self) -> Iterator["Tree"]:
17
40
  yield self
18
41
  for child in self.children:
@@ -29,6 +52,9 @@ class Tree:
29
52
  return node
30
53
  raise ValueError(f"Node with id {id} not found.")
31
54
 
55
+ def is_leaf(self) -> bool:
56
+ return not self.children
57
+
32
58
  def get_leaves(self) -> list["Tree"]:
33
59
  return [node for node in self if not node.children]
34
60
 
@@ -38,11 +64,19 @@ class Tree:
38
64
  raise ValueError(f"Branch length of node {self.id} is not set.")
39
65
  return self.branch_length + parent_time
40
66
 
41
- def is_leaf(self) -> bool:
42
- return not self.children
67
+ def set(self, key: str, value: Any) -> None:
68
+ self._features[key] = value
69
+
70
+ def get(self, key: str) -> Any:
71
+ return self._features.get(key)
72
+
73
+ def delete(self, key: str) -> None:
74
+ del self._features[key]
43
75
 
44
- def copy(self) -> "Tree":
76
+ def copy(self):
45
77
  new_tree = Tree(self.id, self.branch_length)
78
+ for key, value in self._features.items():
79
+ new_tree.set(key, value)
46
80
  for child in self.children:
47
81
  new_tree.add_child(child.copy())
48
82
  return new_tree
@@ -51,4 +85,4 @@ class Tree:
51
85
  return self.preorder_traversal()
52
86
 
53
87
  def __repr__(self) -> str:
54
- return f"TreeNode(id='{self.id}', branch_length={self.branch_length})"
88
+ return f"TreeNode(id='{self.id}', branch_length={self.branch_length}, features={self.features})"
phylogenie/treesimulator/__init__.py CHANGED
@@ -1,5 +1,13 @@
1
1
  from phylogenie.treesimulator.events import (
2
+ Birth,
3
+ BirthWithContactTracing,
4
+ Death,
2
5
  Event,
6
+ Migration,
7
+ Mutation,
8
+ MutationTargetType,
9
+ Sampling,
10
+ SamplingWithContactTracing,
3
11
  get_BD_events,
4
12
  get_BDEI_events,
5
13
  get_BDSS_events,
@@ -11,7 +19,15 @@ from phylogenie.treesimulator.events import (
11
19
  from phylogenie.treesimulator.gillespie import generate_trees, simulate_tree
12
20
 
13
21
  __all__ = [
22
+ "Birth",
23
+ "BirthWithContactTracing",
24
+ "Death",
14
25
  "Event",
26
+ "Migration",
27
+ "Mutation",
28
+ "MutationTargetType",
29
+ "Sampling",
30
+ "SamplingWithContactTracing",
15
31
  "get_BD_events",
16
32
  "get_BDEI_events",
17
33
  "get_BDSS_events",
phylogenie/treesimulator/events/__init__.py ADDED
@@ -0,0 +1,39 @@
1
+ from phylogenie.treesimulator.events.contact_tracing import (
2
+ BirthWithContactTracing,
3
+ SamplingWithContactTracing,
4
+ get_contact_tracing_events,
5
+ )
6
+ from phylogenie.treesimulator.events.core import (
7
+ Birth,
8
+ Death,
9
+ Event,
10
+ Migration,
11
+ Sampling,
12
+ get_BD_events,
13
+ get_BDEI_events,
14
+ get_BDSS_events,
15
+ get_canonical_events,
16
+ get_epidemiological_events,
17
+ get_FBD_events,
18
+ )
19
+ from phylogenie.treesimulator.events.mutations import Mutation
20
+ from phylogenie.treesimulator.events.mutations import TargetType as MutationTargetType
21
+
22
+ __all__ = [
23
+ "Birth",
24
+ "BirthWithContactTracing",
25
+ "Death",
26
+ "Event",
27
+ "Migration",
28
+ "Mutation",
29
+ "Sampling",
30
+ "SamplingWithContactTracing",
31
+ "MutationTargetType",
32
+ "get_BD_events",
33
+ "get_BDEI_events",
34
+ "get_BDSS_events",
35
+ "get_canonical_events",
36
+ "get_contact_tracing_events",
37
+ "get_epidemiological_events",
38
+ "get_FBD_events",
39
+ ]
phylogenie/treesimulator/events/contact_tracing.py ADDED
@@ -0,0 +1,125 @@
1
+ from collections import defaultdict
2
+ from collections.abc import Sequence
3
+
4
+ import numpy as np
5
+ from numpy.random import Generator
6
+
7
+ from phylogenie.skyline import SkylineParameterLike, skyline_parameter
8
+ from phylogenie.treesimulator.events.core import Birth, Death, Migration, Sampling
9
+ from phylogenie.treesimulator.model import Event, Model
10
+
11
+ CT_POSTFIX = "-CT"
12
+ CONTACTS_KEY = "CONTACTS"
13
+
14
+
15
+ def _get_CT_state(state: str) -> str:
16
+ return f"{state}{CT_POSTFIX}"
17
+
18
+
19
+ def _is_CT_state(state: str) -> bool:
20
+ return state.endswith(CT_POSTFIX)
21
+
22
+
23
+ class BirthWithContactTracing(Event):
24
+ def __init__(self, state: str, rate: SkylineParameterLike, child_state: str):
25
+ super().__init__(state, rate)
26
+ self.child_state = child_state
27
+
28
+ def apply(self, model: Model, time: float, rng: Generator) -> None:
29
+ individual = self.draw_individual(model, rng)
30
+ new_individual = model.birth_from(individual, self.child_state, time)
31
+ if CONTACTS_KEY not in model.context:
32
+ model.context[CONTACTS_KEY] = defaultdict(list)
33
+ model.context[CONTACTS_KEY][individual].append(new_individual)
34
+ model.context[CONTACTS_KEY][new_individual].append(individual)
35
+
36
+ def __repr__(self) -> str:
37
+ return f"BirthWithContactTracing(state={self.state}, rate={self.rate}, child_state={self.child_state})"
38
+
39
+
40
+ class SamplingWithContactTracing(Event):
41
+ def __init__(
42
+ self,
43
+ state: str,
44
+ rate: SkylineParameterLike,
45
+ max_notified_contacts: int,
46
+ notification_probability: SkylineParameterLike,
47
+ ):
48
+ super().__init__(state, rate)
49
+ self.max_notified_contacts = max_notified_contacts
50
+ self.notification_probability = skyline_parameter(notification_probability)
51
+
52
+ def apply(self, model: Model, time: float, rng: Generator) -> None:
53
+ individual = self.draw_individual(model, rng)
54
+ model.sample(individual, time, True)
55
+ population = model.get_population()
56
+ if CONTACTS_KEY not in model.context:
57
+ return
58
+ contacts = model.context[CONTACTS_KEY][individual]
59
+ for contact in contacts[-self.max_notified_contacts :]:
60
+ if contact in population:
61
+ state = model.get_state(contact)
62
+ p = self.notification_probability.get_value_at_time(time)
63
+ if not _is_CT_state(state) and rng.random() < p:
64
+ model.migrate(contact, _get_CT_state(state), time)
65
+
66
+ def __repr__(self) -> str:
67
+ return f"SamplingWithContactTracing(state={self.state}, rate={self.rate}, max_notified_contacts={self.max_notified_contacts}, notification_probability={self.notification_probability})"
68
+
69
+
70
+ def get_contact_tracing_events(
71
+ events: Sequence[Event],
72
+ max_notified_contacts: int = 1,
73
+ notification_probability: SkylineParameterLike = 1,
74
+ sampling_rate_after_notification: SkylineParameterLike = np.inf,
75
+ samplable_states_after_notification: list[str] | None = None,
76
+ ) -> list[Event]:
77
+ ct_events: list[Event] = []
78
+ notification_probability = skyline_parameter(notification_probability)
79
+ sampling_rate_after_notification = skyline_parameter(
80
+ sampling_rate_after_notification
81
+ )
82
+ for event in events:
83
+ state, rate = event.state, event.rate
84
+ if isinstance(event, Migration):
85
+ ct_events.append(event)
86
+ ct_events.append(
87
+ Migration(_get_CT_state(state), rate, _get_CT_state(event.target_state))
88
+ )
89
+ elif isinstance(event, Birth):
90
+ ct_events.append(BirthWithContactTracing(state, rate, event.child_state))
91
+ ct_events.append(
92
+ BirthWithContactTracing(_get_CT_state(state), rate, event.child_state)
93
+ )
94
+ elif isinstance(event, Sampling):
95
+ if not event.removal:
96
+ raise ValueError(
97
+ "Contact tracing requires removal to be set for all sampling events."
98
+ )
99
+ ct_events.append(
100
+ SamplingWithContactTracing(
101
+ state, rate, max_notified_contacts, notification_probability
102
+ )
103
+ )
104
+ elif isinstance(event, Death):
105
+ ct_events.append(event)
106
+ else:
107
+ raise NotImplementedError(
108
+ f"Unsupported event type {type(event)} for contact tracing."
109
+ )
110
+
111
+ for state in (
112
+ samplable_states_after_notification
113
+ if samplable_states_after_notification is not None
114
+ else {e.state for e in events}
115
+ ):
116
+ ct_events.append(
117
+ SamplingWithContactTracing(
118
+ _get_CT_state(state),
119
+ sampling_rate_after_notification,
120
+ max_notified_contacts,
121
+ notification_probability,
122
+ )
123
+ )
124
+
125
+ return ct_events