phylogenie 2.1.30__py3-none-any.whl → 3.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
phylogenie/__init__.py CHANGED
@@ -11,7 +11,7 @@ from phylogenie.generators import (
11
11
  FBDTreeDatasetGenerator,
12
12
  TreeDatasetGeneratorConfig,
13
13
  )
14
- from phylogenie.io import dump_newick, load_fasta, load_newick, load_nexus
14
+ from phylogenie.io import load_fasta
15
15
  from phylogenie.msa import MSA
16
16
  from phylogenie.skyline import (
17
17
  SkylineMatrix,
@@ -25,41 +25,47 @@ from phylogenie.skyline import (
25
25
  skyline_parameter,
26
26
  skyline_vector,
27
27
  )
28
- from phylogenie.tree import Tree
29
28
  from phylogenie.treesimulator import (
30
29
  Birth,
31
30
  BirthWithContactTracing,
32
31
  Death,
33
32
  Event,
34
33
  EventType,
34
+ Feature,
35
35
  Migration,
36
36
  Mutation,
37
37
  Sampling,
38
38
  SamplingWithContactTracing,
39
+ Tree,
40
+ compute_mean_leaf_pairwise_distance,
41
+ compute_sackin_index,
42
+ dump_newick,
39
43
  generate_trees,
40
44
  get_BD_events,
41
45
  get_BDEI_events,
42
46
  get_BDSS_events,
43
47
  get_canonical_events,
44
48
  get_contact_tracing_events,
49
+ get_distance,
45
50
  get_epidemiological_events,
46
51
  get_FBD_events,
47
- simulate_tree,
48
- )
49
- from phylogenie.utils import (
50
- compute_colless_index,
51
- compute_mean_leaf_pairwise_distance,
52
- compute_sackin_index,
53
- get_distance,
54
52
  get_mrca,
53
+ get_mutation_id,
55
54
  get_node_depth_levels,
56
55
  get_node_depths,
57
56
  get_node_height_levels,
58
57
  get_node_heights,
59
58
  get_node_leaf_counts,
59
+ get_node_state,
60
+ load_newick,
61
+ load_nexus,
62
+ set_features,
63
+ simulate_tree,
60
64
  )
61
65
 
62
66
  __all__ = [
67
+ "Coloring",
68
+ "draw_tree",
63
69
  "AliSimDatasetGenerator",
64
70
  "BDEITreeDatasetGenerator",
65
71
  "BDSSTreeDatasetGenerator",
@@ -69,51 +75,52 @@ __all__ = [
69
75
  "DatasetGeneratorConfig",
70
76
  "EpidemiologicalTreeDatasetGenerator",
71
77
  "FBDTreeDatasetGenerator",
78
+ "TreeDatasetGeneratorConfig",
79
+ "load_fasta",
80
+ "MSA",
72
81
  "SkylineMatrix",
73
82
  "SkylineMatrixCoercible",
74
- "skyline_matrix",
75
83
  "SkylineParameter",
76
84
  "SkylineParameterLike",
77
- "skyline_parameter",
78
85
  "SkylineVector",
79
86
  "SkylineVectorCoercible",
80
87
  "SkylineVectorLike",
88
+ "skyline_matrix",
89
+ "skyline_parameter",
81
90
  "skyline_vector",
82
- "Tree",
83
- "TreeDatasetGeneratorConfig",
84
91
  "Birth",
85
92
  "BirthWithContactTracing",
86
93
  "Death",
87
94
  "Event",
88
95
  "EventType",
96
+ "Feature",
89
97
  "Migration",
90
98
  "Mutation",
91
99
  "Sampling",
92
100
  "SamplingWithContactTracing",
101
+ "Tree",
102
+ "compute_mean_leaf_pairwise_distance",
103
+ "compute_sackin_index",
104
+ "dump_newick",
105
+ "generate_trees",
93
106
  "get_BD_events",
94
107
  "get_BDEI_events",
95
108
  "get_BDSS_events",
96
109
  "get_canonical_events",
97
110
  "get_contact_tracing_events",
111
+ "get_distance",
98
112
  "get_epidemiological_events",
99
113
  "get_FBD_events",
100
- "generate_trees",
101
- "simulate_tree",
102
- "dump_newick",
103
- "load_nexus",
104
- "load_fasta",
105
- "load_newick",
106
- "MSA",
107
- "Coloring",
108
- "draw_tree",
109
- "compute_colless_index",
110
- "compute_mean_leaf_pairwise_distance",
111
- "compute_sackin_index",
112
- "get_distance",
113
114
  "get_mrca",
114
- "get_node_depths",
115
+ "get_mutation_id",
115
116
  "get_node_depth_levels",
116
- "get_node_heights",
117
+ "get_node_depths",
117
118
  "get_node_height_levels",
119
+ "get_node_heights",
118
120
  "get_node_leaf_counts",
121
+ "get_node_state",
122
+ "load_newick",
123
+ "load_nexus",
124
+ "set_features",
125
+ "simulate_tree",
119
126
  ]
phylogenie/draw.py CHANGED
@@ -1,6 +1,5 @@
1
1
  from enum import Enum
2
- from itertools import islice
3
- from typing import Any
2
+ from typing import Any, Callable
4
3
 
5
4
  import matplotlib.colors as mcolors
6
5
  import matplotlib.patches as mpatches
@@ -8,8 +7,7 @@ import matplotlib.pyplot as plt
8
7
  from matplotlib.axes import Axes
9
8
  from mpl_toolkits.axes_grid1.inset_locator import inset_axes # pyright: ignore
10
9
 
11
- from phylogenie.tree import Tree
12
- from phylogenie.utils import get_node_depth_levels, get_node_depths
10
+ from phylogenie.treesimulator import Tree, get_node_depth_levels, get_node_depths
13
11
 
14
12
 
15
13
  class Coloring(str, Enum):
@@ -26,7 +24,7 @@ def _draw_colored_tree(tree: Tree, ax: Axes, colors: Color | dict[Tree, Color])
26
24
 
27
25
  xs = (
28
26
  get_node_depth_levels(tree)
29
- if any(node.branch_length is None for node in islice(tree, 1, None))
27
+ if any(node.branch_length is None for node in tree.iter_descendants())
30
28
  else get_node_depths(tree)
31
29
  )
32
30
  ys: dict[Tree, float] = {node: i for i, node in enumerate(tree.get_leaves())}
@@ -50,7 +48,7 @@ def _draw_colored_tree(tree: Tree, ax: Axes, colors: Color | dict[Tree, Color])
50
48
  def draw_tree(
51
49
  tree: Tree,
52
50
  ax: Axes | None = None,
53
- color_by: str | None = None,
51
+ color_by: str | dict[Tree, Any] | None = None,
54
52
  coloring: str | Coloring | None = None,
55
53
  default_color: Color = "black",
56
54
  cmap: str | None = None,
@@ -69,33 +67,35 @@ def draw_tree(
69
67
  if color_by is None:
70
68
  return _draw_colored_tree(tree, ax, colors=default_color)
71
69
 
72
- features = [node.get(color_by) for node in tree if color_by in node.features]
70
+ if isinstance(color_by, dict):
71
+ features = {node: color_by[node] for node in tree if node in color_by}
72
+ else:
73
+ features = {node: node[color_by] for node in tree if color_by in node.metadata}
73
74
 
74
75
  if coloring is None:
75
76
  coloring = (
76
77
  Coloring.CONTINUOUS
77
- if any(isinstance(f, float) for f in features)
78
+ if any(isinstance(f, float) for f in features.values())
78
79
  else Coloring.DISCRETE
79
80
  )
80
81
 
82
+ def _get_colors(feature_map: Callable[[Any], Color]) -> dict[Tree, Color]:
83
+ return {
84
+ node: feature_map(features[node]) if node in features else default_color
85
+ for node in tree
86
+ }
87
+
81
88
  if coloring == Coloring.DISCRETE:
82
- if any(isinstance(f, float) for f in features):
89
+ if any(isinstance(f, float) for f in features.values()):
83
90
  raise ValueError(
84
91
  "Discrete coloring selected but feature values are not all categorical."
85
92
  )
86
93
 
87
94
  colormap = plt.get_cmap("tab20" if cmap is None else cmap)
88
95
  feature_colors = {
89
- f: mcolors.to_hex(colormap(i)) for i, f in enumerate(set(features))
90
- }
91
- colors = {
92
- node: (
93
- feature_colors[node.get(color_by)]
94
- if color_by in node.features
95
- else default_color
96
- )
97
- for node in tree
96
+ f: mcolors.to_hex(colormap(i)) for i, f in enumerate(set(features.values()))
98
97
  }
98
+ colors = _get_colors(lambda f: feature_colors[f])
99
99
 
100
100
  if show_legend:
101
101
  legend_handles = [
@@ -105,7 +105,7 @@ def draw_tree(
105
105
  )
106
106
  for f in feature_colors
107
107
  ]
108
- if any(color_by not in node.features for node in tree):
108
+ if any(color_by not in node.metadata for node in tree):
109
109
  legend_handles.append(mpatches.Patch(color=default_color, label="NA"))
110
110
  if legend_kwargs is None:
111
111
  legend_kwargs = {}
@@ -114,18 +114,11 @@ def draw_tree(
114
114
  return _draw_colored_tree(tree, ax, colors)
115
115
 
116
116
  if coloring == Coloring.CONTINUOUS:
117
- vmin = min(features) if vmin is None else vmin
118
- vmax = max(features) if vmax is None else vmax
117
+ vmin = min(features.values()) if vmin is None else vmin
118
+ vmax = max(features.values()) if vmax is None else vmax
119
119
  norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
120
120
  colormap = plt.get_cmap("viridis" if cmap is None else cmap)
121
- colors = {
122
- node: (
123
- colormap(norm(float(node.get(color_by))))
124
- if color_by in node.features
125
- else default_color
126
- )
127
- for node in tree
128
- }
121
+ colors = _get_colors(lambda f: colormap(norm(float(f))))
129
122
 
130
123
  if show_hist:
131
124
  default_hist_axes_kwargs = {"width": "25%", "height": "25%"}
@@ -8,8 +8,7 @@ from numpy.random import Generator, default_rng
8
8
  from phylogenie.generators.dataset import DatasetGenerator, DataType
9
9
  from phylogenie.generators.factories import data, string
10
10
  from phylogenie.generators.trees import TreeDatasetGeneratorConfig
11
- from phylogenie.io import dump_newick
12
- from phylogenie.utils import get_node_depths
11
+ from phylogenie.treesimulator import dump_newick, get_node_depths
13
12
 
14
13
  MSAS_DIRNAME = "MSAs"
15
14
  TREES_DIRNAME = "trees"
@@ -60,26 +59,26 @@ class AliSimDatasetGenerator(DatasetGenerator):
60
59
  tree_filename = f"{filename}.temp-tree"
61
60
  msa_filename = filename
62
61
 
63
- d: dict[str, Any] = {"file_id": Path(msa_filename).stem}
62
+ md: dict[str, Any] = {"file_id": Path(msa_filename).stem}
64
63
  rng = default_rng(seed)
65
64
  while True:
66
- d.update(data(context, rng))
65
+ md.update(data(context, rng))
67
66
  try:
68
- tree, metadata = self.trees.simulate_one(d, seed)
67
+ tree, metadata = self.trees.simulate_one(md, seed)
69
68
  break
70
69
  except TimeoutError:
71
70
  print(
72
71
  "Tree simulation timed out, retrying with different parameters..."
73
72
  )
74
- d.update(metadata)
73
+ md.update(metadata)
75
74
 
76
75
  times = get_node_depths(tree)
77
76
  for leaf in tree.get_leaves():
78
77
  leaf.name += f"|{times[leaf]}"
79
78
  dump_newick(tree, f"{tree_filename}.nwk")
80
79
 
81
- self._generate_one_from_tree(msa_filename, f"{tree_filename}.nwk", rng, d)
80
+ self._generate_one_from_tree(msa_filename, f"{tree_filename}.nwk", rng, md)
82
81
  if not self.keep_trees:
83
82
  os.remove(f"{tree_filename}.nwk")
84
83
 
85
- return d
84
+ return md
@@ -1,7 +1,29 @@
1
+ from typing import Any
2
+
3
+ from numpy.random import Generator
4
+ from pydantic import BaseModel, ConfigDict
5
+
1
6
  import phylogenie.typings as pgt
2
- from phylogenie.models import Distribution, StrictBaseModel
3
7
  from phylogenie.treesimulator import EventType
4
8
 
9
+
10
+ class StrictBaseModel(BaseModel):
11
+ model_config = ConfigDict(extra="forbid")
12
+
13
+
14
+ class Distribution(BaseModel):
15
+ type: str
16
+ model_config = ConfigDict(extra="allow")
17
+
18
+ @property
19
+ def args(self) -> dict[str, Any]:
20
+ assert self.model_extra is not None
21
+ return self.model_extra
22
+
23
+ def __call__(self, rng: Generator) -> Any:
24
+ return getattr(rng, self.type)(**self.args)
25
+
26
+
5
27
  Integer = str | int
6
28
  Scalar = str | pgt.Scalar
7
29
  ManyScalars = str | pgt.Many[Scalar]
@@ -8,7 +8,7 @@ import pandas as pd
8
8
  from numpy.random import Generator, default_rng
9
9
  from tqdm import tqdm
10
10
 
11
- from phylogenie.models import Distribution, StrictBaseModel
11
+ from phylogenie.generators.configs import Distribution, StrictBaseModel
12
12
 
13
13
 
14
14
  class DataType(str, Enum):
@@ -1,5 +1,5 @@
1
1
  import re
2
- from typing import Any
2
+ from typing import Any, Callable
3
3
 
4
4
  import numpy as np
5
5
  from numpy.random import Generator
@@ -8,7 +8,6 @@ import phylogenie.generators.configs as cfg
8
8
  import phylogenie.generators.typeguards as ctg
9
9
  import phylogenie.typeguards as tg
10
10
  import phylogenie.typings as pgt
11
- from phylogenie.models import Distribution
12
11
  from phylogenie.skyline import (
13
12
  SkylineMatrix,
14
13
  SkylineMatrixCoercible,
@@ -24,11 +23,7 @@ def _eval_expression(expression: str, data: dict[str, Any]) -> Any:
24
23
  return np.array(
25
24
  eval(
26
25
  expression,
27
- {
28
- "__builtins__": __builtins__,
29
- "np": np,
30
- **{k: np.array(v) for k, v in data.items()},
31
- },
26
+ {"np": np, **{k: np.array(v) for k, v in data.items()}},
32
27
  )
33
28
  ).tolist()
34
29
 
@@ -214,12 +209,12 @@ def skyline_matrix(
214
209
  return SkylineMatrix(value=value, change_times=change_times)
215
210
 
216
211
 
217
- def distribution(x: Distribution, data: dict[str, Any]) -> Distribution:
212
+ def distribution(x: cfg.Distribution, data: dict[str, Any]) -> cfg.Distribution:
218
213
  args = x.args
219
214
  for arg_name, arg_value in args.items():
220
215
  if isinstance(arg_value, str):
221
216
  args[arg_name] = _eval_expression(arg_value, data)
222
- return Distribution(type=x.type, **args)
217
+ return cfg.Distribution(type=x.type, **args)
223
218
 
224
219
 
225
220
  def mutations(
@@ -227,11 +222,14 @@ def mutations(
227
222
  data: dict[str, Any],
228
223
  states: set[str],
229
224
  rates_to_log: list[EventType] | None,
225
+ rng: Generator,
230
226
  ) -> list[Mutation]:
231
227
  mutations: list[Mutation] = []
232
228
  for m in x:
233
229
  rate = skyline_parameter(m.rate, data)
234
- rate_scalers = {k: distribution(v, data) for k, v in m.rate_scalers.items()}
230
+ rate_scalers: dict[EventType, Callable[[], float]] = {
231
+ k: lambda: distribution(v, data)(rng) for k, v in m.rate_scalers.items()
232
+ }
235
233
  if m.state is None:
236
234
  mutations.extend(
237
235
  Mutation(s, rate, rate_scalers, rates_to_log) for s in states
@@ -241,11 +239,10 @@ def mutations(
241
239
  return mutations
242
240
 
243
241
 
244
- def data(context: dict[str, Distribution] | None, rng: Generator) -> dict[str, Any]:
242
+ def data(context: dict[str, cfg.Distribution] | None, rng: Generator) -> dict[str, Any]:
245
243
  if context is None:
246
244
  return {}
247
245
  data: dict[str, Any] = {}
248
246
  for k, v in context.items():
249
- dist = distribution(v, data)
250
- data[k] = np.array(getattr(rng, dist.type)(**dist.args)).tolist()
247
+ data[k] = np.array(distribution(v, data)(rng)).tolist()
251
248
  return data
@@ -1,13 +1,14 @@
1
1
  from abc import abstractmethod
2
2
  from enum import Enum
3
3
  from pathlib import Path
4
- from typing import Annotated, Any, Literal
4
+ from typing import Annotated, Any, Callable, Literal
5
5
 
6
6
  import numpy as np
7
7
  from numpy.random import default_rng
8
8
  from pydantic import Field
9
9
 
10
10
  import phylogenie.generators.configs as cfg
11
+ from phylogenie.generators.configs import Distribution
11
12
  from phylogenie.generators.dataset import DatasetGenerator, DataType
12
13
  from phylogenie.generators.factories import (
13
14
  data,
@@ -18,13 +19,12 @@ from phylogenie.generators.factories import (
18
19
  skyline_parameter,
19
20
  skyline_vector,
20
21
  )
21
- from phylogenie.io import dump_newick
22
- from phylogenie.models import Distribution
23
- from phylogenie.tree import Tree
24
22
  from phylogenie.treesimulator import (
25
23
  Event,
26
24
  EventType,
27
25
  Feature,
26
+ Tree,
27
+ dump_newick,
28
28
  get_BD_events,
29
29
  get_BDEI_events,
30
30
  get_BDSS_events,
@@ -50,13 +50,13 @@ class TreeDatasetGenerator(DatasetGenerator):
50
50
  data_type: Literal[DataType.TREES] = DataType.TREES
51
51
  mutations: list[cfg.Mutation] = Field(default_factory=lambda: [])
52
52
  rates_to_log: list[EventType] | None = None
53
- min_tips: cfg.Integer = 1
54
- max_tips: cfg.Integer | None = None
53
+ n_tips: cfg.Integer | None = None
55
54
  max_time: cfg.Scalar = np.inf
56
55
  init_state: str | None = None
57
56
  sampling_probability_at_present: cfg.Scalar = 0.0
58
57
  timeout: float = np.inf
59
58
  node_features: list[Feature] | None = None
59
+ acceptance_criterion: str | None = None
60
60
 
61
61
  @abstractmethod
62
62
  def _get_events(self, data: dict[str, Any]) -> list[Event]: ...
@@ -71,11 +71,19 @@ class TreeDatasetGenerator(DatasetGenerator):
71
71
  )
72
72
  events = self._get_events(data)
73
73
  states = {e.state for e in events}
74
- events += mutations(self.mutations, data, states, self.rates_to_log)
74
+ events += mutations(
75
+ self.mutations, data, states, self.rates_to_log, default_rng(seed)
76
+ )
77
+ acceptance_criterion: None | Callable[[Tree], bool] = (
78
+ None
79
+ if self.acceptance_criterion is None
80
+ else lambda tree: eval(
81
+ self.acceptance_criterion, {}, {"tree": tree} # pyright: ignore
82
+ )
83
+ )
75
84
  return simulate_tree(
76
85
  events=events,
77
- min_tips=integer(self.min_tips, data),
78
- max_tips=None if self.max_tips is None else integer(self.max_tips, data),
86
+ n_tips=None if self.n_tips is None else integer(self.n_tips, data),
79
87
  max_time=scalar(self.max_time, data),
80
88
  init_state=init_state,
81
89
  sampling_probability_at_present=scalar(
@@ -83,6 +91,7 @@ class TreeDatasetGenerator(DatasetGenerator):
83
91
  ),
84
92
  seed=seed,
85
93
  timeout=self.timeout,
94
+ acceptance_criterion=acceptance_criterion,
86
95
  )
87
96
 
88
97
  def generate_one(
@@ -157,7 +166,7 @@ class FBDTreeDatasetGenerator(TreeDatasetGenerator):
157
166
  class ContactTracingTreeDatasetGenerator(TreeDatasetGenerator):
158
167
  max_notified_contacts: cfg.Integer = 1
159
168
  notification_probability: cfg.SkylineParameter = 0.0
160
- sampling_rate_after_notification: cfg.SkylineParameter = np.inf
169
+ sampling_rate_after_notification: cfg.SkylineParameter = 2**32
161
170
  samplable_states_after_notification: list[str] | None = None
162
171
 
163
172
  @abstractmethod
phylogenie/io/__init__.py CHANGED
@@ -1,5 +1,3 @@
1
1
  from phylogenie.io.fasta import load_fasta
2
- from phylogenie.io.newick import dump_newick, load_newick
3
- from phylogenie.io.nexus import load_nexus
4
2
 
5
- __all__ = ["load_fasta", "load_newick", "dump_newick", "load_nexus"]
3
+ __all__ = ["load_fasta"]
phylogenie/mixins.py ADDED
@@ -0,0 +1,41 @@
1
+ from types import MappingProxyType
2
+ from typing import Any, Mapping, Optional
3
+
4
+
5
+ class MetadataMixin:
6
+ """A mixin that provides metadata management with dictionary-like access."""
7
+
8
+ def __init__(self) -> None:
9
+ self._metadata: dict[str, Any] = {}
10
+
11
+ @property
12
+ def metadata(self) -> Mapping[str, Any]:
13
+ """Return a read-only view of all metadata."""
14
+ return MappingProxyType(self._metadata)
15
+
16
+ def set(self, key: str, value: Any) -> None:
17
+ """Set or update a metadata value."""
18
+ self._metadata[key] = value
19
+
20
+ def update(self, metadata: Mapping[str, Any]) -> None:
21
+ """Bulk update metadata values."""
22
+ self._metadata.update(metadata)
23
+
24
+ def get(self, key: str, default: Optional[Any] = None) -> Any:
25
+ """Get a metadata value, returning `default` if not found."""
26
+ return self._metadata.get(key, default)
27
+
28
+ def delete(self, key: str) -> None:
29
+ """Delete a metadata if it exists, else do nothing."""
30
+ self._metadata.pop(key, None)
31
+
32
+ def clear(self) -> None:
33
+ """Remove all metadata."""
34
+ self._metadata.clear()
35
+
36
+ # Dict-like behavior
37
+ def __getitem__(self, key: str) -> Any:
38
+ return self._metadata[key]
39
+
40
+ def __setitem__(self, key: str, value: Any) -> None:
41
+ self._metadata[key] = value
@@ -19,7 +19,20 @@ from phylogenie.treesimulator.events import (
19
19
  )
20
20
  from phylogenie.treesimulator.features import Feature, set_features
21
21
  from phylogenie.treesimulator.gillespie import generate_trees, simulate_tree
22
+ from phylogenie.treesimulator.io import dump_newick, load_newick, load_nexus
22
23
  from phylogenie.treesimulator.model import get_node_state
24
+ from phylogenie.treesimulator.tree import Tree
25
+ from phylogenie.treesimulator.utils import (
26
+ compute_mean_leaf_pairwise_distance,
27
+ compute_sackin_index,
28
+ get_distance,
29
+ get_mrca,
30
+ get_node_depth_levels,
31
+ get_node_depths,
32
+ get_node_height_levels,
33
+ get_node_heights,
34
+ get_node_leaf_counts,
35
+ )
23
36
 
24
37
  __all__ = [
25
38
  "Birth",
@@ -38,10 +51,23 @@ __all__ = [
38
51
  "get_contact_tracing_events",
39
52
  "get_epidemiological_events",
40
53
  "get_FBD_events",
41
- "generate_trees",
42
- "simulate_tree",
43
54
  "get_mutation_id",
44
- "get_node_state",
45
55
  "Feature",
46
56
  "set_features",
57
+ "simulate_tree",
58
+ "dump_newick",
59
+ "load_newick",
60
+ "load_nexus",
61
+ "generate_trees",
62
+ "get_node_state",
63
+ "Tree",
64
+ "compute_mean_leaf_pairwise_distance",
65
+ "compute_sackin_index",
66
+ "get_distance",
67
+ "get_mrca",
68
+ "get_node_depth_levels",
69
+ "get_node_depths",
70
+ "get_node_height_levels",
71
+ "get_node_heights",
72
+ "get_node_leaf_counts",
47
73
  ]
@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
2
2
  from enum import Enum
3
3
  from typing import Any
4
4
 
5
- import numpy as np
6
5
  from numpy.random import Generator
7
6
 
8
7
  from phylogenie.skyline import SkylineParameterLike, skyline_parameter
@@ -32,8 +31,6 @@ class Event(ABC):
32
31
  def get_propensity(self, model: Model, time: float) -> float:
33
32
  n_individuals = model.count_individuals(self.state)
34
33
  rate = self.rate.get_value_at_time(time)
35
- if rate == np.inf and not n_individuals:
36
- return 0
37
34
  return rate * n_individuals
38
35
 
39
36
  @abstractmethod
@@ -2,7 +2,6 @@ from collections import defaultdict
2
2
  from collections.abc import Sequence
3
3
  from copy import deepcopy
4
4
 
5
- import numpy as np
6
5
  from numpy.random import Generator
7
6
 
8
7
  from phylogenie.skyline import SkylineParameterLike, skyline_parameter
@@ -32,10 +31,10 @@ class BirthWithContactTracing(Event):
32
31
  def apply(self, model: Model, events: list[Event], time: float, rng: Generator):
33
32
  individual = self.draw_individual(model, rng)
34
33
  new_individual = model.birth_from(individual, self.child_state, time)
35
- if CONTACTS_KEY not in model.context:
36
- model.context[CONTACTS_KEY] = defaultdict(list)
37
- model.context[CONTACTS_KEY][individual].append(new_individual)
38
- model.context[CONTACTS_KEY][new_individual].append(individual)
34
+ if CONTACTS_KEY not in model.metadata:
35
+ model[CONTACTS_KEY] = defaultdict(list)
36
+ model[CONTACTS_KEY][individual].append(new_individual)
37
+ model[CONTACTS_KEY][new_individual].append(individual)
39
38
 
40
39
  def __repr__(self) -> str:
41
40
  return f"BirthWithContactTracing(state={self.state}, rate={self.rate}, child_state={self.child_state})"
@@ -59,9 +58,9 @@ class SamplingWithContactTracing(Event):
59
58
  individual = self.draw_individual(model, rng)
60
59
  model.sample(individual, time, True)
61
60
  population = model.get_population()
62
- if CONTACTS_KEY not in model.context:
61
+ if CONTACTS_KEY not in model.metadata:
63
62
  return
64
- contacts = model.context[CONTACTS_KEY][individual]
63
+ contacts = model[CONTACTS_KEY][individual]
65
64
  for contact in contacts[-self.max_notified_contacts :]:
66
65
  if contact in population:
67
66
  state = model.get_state(contact)
@@ -76,8 +75,8 @@ class SamplingWithContactTracing(Event):
76
75
  def get_contact_tracing_events(
77
76
  events: Sequence[Event],
78
77
  max_notified_contacts: int = 1,
79
- notification_probability: SkylineParameterLike = 1,
80
- sampling_rate_after_notification: SkylineParameterLike = np.inf,
78
+ notification_probability: SkylineParameterLike = 0.0,
79
+ sampling_rate_after_notification: SkylineParameterLike = 2**32,
81
80
  samplable_states_after_notification: list[str] | None = None,
82
81
  ) -> list[Event]:
83
82
  ct_events: list[Event] = []