phylogenie-2.0.5-py3-none-any.whl → phylogenie-2.0.7-py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.

Potentially problematic release.


This version of phylogenie might be problematic. (The original page links to more details here; the link target is not preserved in this text.)

@@ -46,11 +46,12 @@ class AliSimDatasetGenerator(DatasetGenerator):
46
46
  self, filename: str, rng: Generator, data: dict[str, Any]
47
47
  ) -> None:
48
48
  if self.keep_trees:
49
- base_dir = Path(filename).parent
50
- file_id = Path(filename).stem
51
- tree_filename = os.path.join(base_dir, TREES_DIRNAME, file_id)
49
+ base_dir, file_id = Path(filename).parent, Path(filename).stem
50
+ trees_dir = os.path.join(base_dir, TREES_DIRNAME)
52
51
  msas_dir = os.path.join(base_dir, MSAS_DIRNAME)
52
+ os.makedirs(trees_dir, exist_ok=True)
53
53
  os.makedirs(msas_dir, exist_ok=True)
54
+ tree_filename = os.path.join(trees_dir, file_id)
54
55
  msa_filename = os.path.join(msas_dir, file_id)
55
56
  else:
56
57
  tree_filename = f"{filename}.temp-tree"
@@ -3,16 +3,16 @@ from pydantic import BaseModel, ConfigDict
3
3
  import phylogenie.typings as pgt
4
4
 
5
5
 
6
- class DistributionConfig(BaseModel):
6
+ class Distribution(BaseModel):
7
7
  type: str
8
8
  model_config = ConfigDict(extra="allow")
9
9
 
10
10
 
11
- IntegerConfig = str | int
12
- ScalarConfig = str | pgt.Scalar
13
- ManyScalarsConfig = str | list[ScalarConfig]
14
- OneOrManyScalarsConfig = ScalarConfig | list[ScalarConfig]
15
- OneOrMany2DScalarsConfig = ScalarConfig | list[list[ScalarConfig]]
11
+ Integer = str | int
12
+ Scalar = str | pgt.Scalar
13
+ ManyScalars = str | list[Scalar]
14
+ OneOrManyScalars = Scalar | list[Scalar]
15
+ OneOrMany2DScalars = Scalar | list[list[Scalar]]
16
16
 
17
17
 
18
18
  class StrictBaseModel(BaseModel):
@@ -20,24 +20,20 @@ class StrictBaseModel(BaseModel):
20
20
 
21
21
 
22
22
  class SkylineParameterModel(StrictBaseModel):
23
- value: ManyScalarsConfig
24
- change_times: ManyScalarsConfig
23
+ value: ManyScalars
24
+ change_times: ManyScalars
25
25
 
26
26
 
27
27
  class SkylineVectorModel(StrictBaseModel):
28
- value: str | list[OneOrManyScalarsConfig]
29
- change_times: ManyScalarsConfig
28
+ value: str | list[OneOrManyScalars]
29
+ change_times: ManyScalars
30
30
 
31
31
 
32
32
  class SkylineMatrixModel(StrictBaseModel):
33
- value: str | list[OneOrMany2DScalarsConfig]
34
- change_times: ManyScalarsConfig
33
+ value: str | list[OneOrMany2DScalars]
34
+ change_times: ManyScalars
35
35
 
36
36
 
37
- SkylineParameterConfig = ScalarConfig | SkylineParameterModel
38
- SkylineVectorConfig = (
39
- str | pgt.Scalar | list[SkylineParameterConfig] | SkylineVectorModel
40
- )
41
- SkylineMatrixConfig = (
42
- str | pgt.Scalar | list[SkylineVectorConfig] | SkylineMatrixModel | None
43
- )
37
+ SkylineParameter = Scalar | SkylineParameterModel
38
+ SkylineVector = str | pgt.Scalar | list[SkylineParameter] | SkylineVectorModel
39
+ SkylineMatrix = str | pgt.Scalar | list[SkylineVector] | SkylineMatrixModel | None
@@ -10,7 +10,7 @@ import pandas as pd
10
10
  from numpy.random import Generator, default_rng
11
11
  from tqdm import tqdm
12
12
 
13
- from phylogenie.generators.configs import DistributionConfig, StrictBaseModel
13
+ import phylogenie.generators.configs as cfg
14
14
 
15
15
 
16
16
  class DataType(str, Enum):
@@ -22,12 +22,12 @@ DATA_DIRNAME = "data"
22
22
  METADATA_FILENAME = "metadata.csv"
23
23
 
24
24
 
25
- class DatasetGenerator(ABC, StrictBaseModel):
25
+ class DatasetGenerator(ABC, cfg.StrictBaseModel):
26
26
  output_dir: str = "phylogenie-outputs"
27
27
  n_samples: int | dict[str, int] = 1
28
28
  n_jobs: int = -1
29
29
  seed: int | None = None
30
- context: dict[str, DistributionConfig] | None = None
30
+ context: dict[str, cfg.Distribution] | None = None
31
31
 
32
32
  @abstractmethod
33
33
  def _generate_one(
@@ -29,7 +29,7 @@ def _eval_expression(expression: str, data: dict[str, Any]) -> Any:
29
29
  ).tolist()
30
30
 
31
31
 
32
- def integer(x: cfg.IntegerConfig, data: dict[str, Any]) -> int:
32
+ def integer(x: cfg.Integer, data: dict[str, Any]) -> int:
33
33
  if isinstance(x, str):
34
34
  e = _eval_expression(x, data)
35
35
  if isinstance(e, int):
@@ -40,7 +40,7 @@ def integer(x: cfg.IntegerConfig, data: dict[str, Any]) -> int:
40
40
  return x
41
41
 
42
42
 
43
- def scalar(x: cfg.ScalarConfig, data: dict[str, Any]) -> pgt.Scalar:
43
+ def scalar(x: cfg.Scalar, data: dict[str, Any]) -> pgt.Scalar:
44
44
  if isinstance(x, str):
45
45
  e = _eval_expression(x, data)
46
46
  if isinstance(e, pgt.Scalar):
@@ -51,7 +51,7 @@ def scalar(x: cfg.ScalarConfig, data: dict[str, Any]) -> pgt.Scalar:
51
51
  return x
52
52
 
53
53
 
54
- def many_scalars(x: cfg.ManyScalarsConfig, data: dict[str, Any]) -> pgt.ManyScalars:
54
+ def many_scalars(x: cfg.ManyScalars, data: dict[str, Any]) -> pgt.ManyScalars:
55
55
  if isinstance(x, str):
56
56
  e = _eval_expression(x, data)
57
57
  if tg.is_many_scalars(e):
@@ -63,7 +63,7 @@ def many_scalars(x: cfg.ManyScalarsConfig, data: dict[str, Any]) -> pgt.ManyScal
63
63
 
64
64
 
65
65
  def one_or_many_scalars(
66
- x: cfg.OneOrManyScalarsConfig, data: dict[str, Any]
66
+ x: cfg.OneOrManyScalars, data: dict[str, Any]
67
67
  ) -> pgt.OneOrManyScalars:
68
68
  if isinstance(x, str):
69
69
  e = _eval_expression(x, data)
@@ -78,9 +78,9 @@ def one_or_many_scalars(
78
78
 
79
79
 
80
80
  def skyline_parameter(
81
- x: cfg.SkylineParameterConfig, data: dict[str, Any]
81
+ x: cfg.SkylineParameter, data: dict[str, Any]
82
82
  ) -> SkylineParameterLike:
83
- if isinstance(x, cfg.ScalarConfig):
83
+ if isinstance(x, cfg.Scalar):
84
84
  return scalar(x, data)
85
85
  return SkylineParameter(
86
86
  value=many_scalars(x.value, data),
@@ -89,7 +89,7 @@ def skyline_parameter(
89
89
 
90
90
 
91
91
  def skyline_vector(
92
- x: cfg.SkylineVectorConfig, data: dict[str, Any]
92
+ x: cfg.SkylineVector, data: dict[str, Any]
93
93
  ) -> SkylineVectorCoercible:
94
94
  if isinstance(x, str):
95
95
  e = _eval_expression(x, data)
@@ -132,7 +132,7 @@ def skyline_vector(
132
132
 
133
133
 
134
134
  def one_or_many_2D_scalars(
135
- x: cfg.OneOrMany2DScalarsConfig, data: dict[str, Any]
135
+ x: cfg.OneOrMany2DScalars, data: dict[str, Any]
136
136
  ) -> pgt.OneOrMany2DScalars:
137
137
  if isinstance(x, str):
138
138
  e = _eval_expression(x, data)
@@ -147,7 +147,7 @@ def one_or_many_2D_scalars(
147
147
 
148
148
 
149
149
  def skyline_matrix(
150
- x: cfg.SkylineMatrixConfig, data: dict[str, Any]
150
+ x: cfg.SkylineMatrix, data: dict[str, Any]
151
151
  ) -> SkylineMatrixCoercible | None:
152
152
  if x is None:
153
153
  return None
@@ -40,15 +40,20 @@ class ParameterizationType(str, Enum):
40
40
 
41
41
  class TreeDatasetGenerator(DatasetGenerator):
42
42
  data_type: Literal[DataType.TREES] = DataType.TREES
43
- min_tips: cfg.IntegerConfig = 1
44
- max_tips: cfg.IntegerConfig | None = None
45
- max_time: cfg.ScalarConfig = np.inf
43
+ min_tips: cfg.Integer = 1
44
+ max_tips: cfg.Integer | None = None
45
+ max_time: cfg.Scalar = np.inf
46
46
  init_state: str | None = None
47
- sampling_probability_at_present: cfg.ScalarConfig = 0.0
47
+ sampling_probability_at_present: cfg.Scalar = 0.0
48
48
  max_tries: int | None = None
49
+ notification_probability: cfg.Scalar = 0.0
50
+ max_notified_contacts: cfg.Integer = 1
51
+ samplable_states_after_notification: list[str] | None = None
52
+ sampling_rate_after_notification: cfg.SkylineParameter = np.inf
53
+ contacts_removal_probability: cfg.SkylineParameter = 1
49
54
 
50
55
  def simulate_one(self, rng: Generator, data: dict[str, Any]) -> Tree | None:
51
- events = self._get_events(rng, data)
56
+ events = self._get_events(data)
52
57
  init_state = (
53
58
  self.init_state
54
59
  if self.init_state is None
@@ -66,12 +71,21 @@ class TreeDatasetGenerator(DatasetGenerator):
66
71
  sampling_probability_at_present=scalar(
67
72
  self.sampling_probability_at_present, data
68
73
  ),
74
+ notification_probability=scalar(self.notification_probability, data),
75
+ max_notified_contacts=integer(self.max_notified_contacts, data),
76
+ samplable_states_after_notification=self.samplable_states_after_notification,
77
+ sampling_rate_after_notification=skyline_parameter(
78
+ self.sampling_rate_after_notification, data
79
+ ),
80
+ contacts_removal_probability=skyline_parameter(
81
+ self.contacts_removal_probability, data
82
+ ),
69
83
  max_tries=self.max_tries,
70
84
  seed=int(rng.integers(2**32)),
71
85
  )
72
86
 
73
87
  @abstractmethod
74
- def _get_events(self, rng: Generator, data: dict[str, Any]) -> list[Event]: ...
88
+ def _get_events(self, data: dict[str, Any]) -> list[Event]: ...
75
89
 
76
90
  def _generate_one(
77
91
  self, filename: str, rng: Generator, data: dict[str, Any]
@@ -85,15 +99,15 @@ class CanonicalTreeDatasetGenerator(TreeDatasetGenerator):
85
99
  parameterization: Literal[ParameterizationType.CANONICAL] = (
86
100
  ParameterizationType.CANONICAL
87
101
  )
88
- sampling_rates: cfg.SkylineVectorConfig
89
- birth_rates: cfg.SkylineVectorConfig = 0
90
- death_rates: cfg.SkylineVectorConfig = 0
91
- removal_probabilities: cfg.SkylineVectorConfig = 0
92
- migration_rates: cfg.SkylineMatrixConfig = None
93
- birth_rates_among_states: cfg.SkylineMatrixConfig = None
94
- states: list[str] | None = None
95
-
96
- def _get_events(self, rng: Generator, data: dict[str, Any]) -> list[Event]:
102
+ states: list[str]
103
+ sampling_rates: cfg.SkylineVector
104
+ birth_rates: cfg.SkylineVector = 0
105
+ death_rates: cfg.SkylineVector = 0
106
+ removal_probabilities: cfg.SkylineVector = 0
107
+ migration_rates: cfg.SkylineMatrix = None
108
+ birth_rates_among_states: cfg.SkylineMatrix = None
109
+
110
+ def _get_events(self, data: dict[str, Any]) -> list[Event]:
97
111
  return get_canonical_events(
98
112
  states=self.states,
99
113
  sampling_rates=skyline_vector(self.sampling_rates, data),
@@ -111,15 +125,15 @@ class EpidemiologicalTreeDatasetGenerator(TreeDatasetGenerator):
111
125
  parameterization: Literal[ParameterizationType.EPIDEMIOLOGICAL] = (
112
126
  ParameterizationType.EPIDEMIOLOGICAL
113
127
  )
114
- states: list[str] | None = None
115
- reproduction_numbers: cfg.SkylineVectorConfig = 0
116
- become_uninfectious_rates: cfg.SkylineVectorConfig = 0
117
- sampling_proportions: cfg.SkylineVectorConfig = 1
118
- removal_probabilities: cfg.SkylineVectorConfig = 1
119
- migration_rates: cfg.SkylineMatrixConfig = None
120
- reproduction_numbers_among_states: cfg.SkylineMatrixConfig = None
121
-
122
- def _get_events(self, rng: Generator, data: dict[str, Any]) -> list[Event]:
128
+ states: list[str]
129
+ reproduction_numbers: cfg.SkylineVector = 0
130
+ become_uninfectious_rates: cfg.SkylineVector = 0
131
+ sampling_proportions: cfg.SkylineVector = 1
132
+ removal_probabilities: cfg.SkylineVector = 1
133
+ migration_rates: cfg.SkylineMatrix = None
134
+ reproduction_numbers_among_states: cfg.SkylineMatrix = None
135
+
136
+ def _get_events(self, data: dict[str, Any]) -> list[Event]:
123
137
  return get_epidemiological_events(
124
138
  states=self.states,
125
139
  reproduction_numbers=skyline_vector(self.reproduction_numbers, data),
@@ -137,15 +151,15 @@ class EpidemiologicalTreeDatasetGenerator(TreeDatasetGenerator):
137
151
 
138
152
  class FBDTreeDatasetGenerator(TreeDatasetGenerator):
139
153
  parameterization: Literal[ParameterizationType.FBD] = ParameterizationType.FBD
140
- states: list[str] | None = None
141
- diversification: cfg.SkylineVectorConfig = 0
142
- turnover: cfg.SkylineVectorConfig = 0
143
- sampling_proportions: cfg.SkylineVectorConfig = 1
144
- removal_probabilities: cfg.SkylineVectorConfig = 0
145
- migration_rates: cfg.SkylineMatrixConfig = None
146
- diversification_between_types: cfg.SkylineMatrixConfig = None
147
-
148
- def _get_events(self, rng: Generator, data: dict[str, Any]) -> list[Event]:
154
+ states: list[str]
155
+ diversification: cfg.SkylineVector = 0
156
+ turnover: cfg.SkylineVector = 0
157
+ sampling_proportions: cfg.SkylineVector = 1
158
+ removal_probabilities: cfg.SkylineVector = 0
159
+ migration_rates: cfg.SkylineMatrix = None
160
+ diversification_between_types: cfg.SkylineMatrix = None
161
+
162
+ def _get_events(self, data: dict[str, Any]) -> list[Event]:
149
163
  return get_FBD_events(
150
164
  states=self.states,
151
165
  diversification=skyline_vector(self.diversification, data),
@@ -161,11 +175,11 @@ class FBDTreeDatasetGenerator(TreeDatasetGenerator):
161
175
 
162
176
  class BDTreeDatasetGenerator(TreeDatasetGenerator):
163
177
  parameterization: Literal[ParameterizationType.BD] = ParameterizationType.BD
164
- reproduction_number: cfg.SkylineParameterConfig
165
- infectious_period: cfg.SkylineParameterConfig
166
- sampling_proportion: cfg.SkylineParameterConfig = 1
178
+ reproduction_number: cfg.SkylineParameter
179
+ infectious_period: cfg.SkylineParameter
180
+ sampling_proportion: cfg.SkylineParameter = 1
167
181
 
168
- def _get_events(self, rng: Generator, data: dict[str, Any]) -> list[Event]:
182
+ def _get_events(self, data: dict[str, Any]) -> list[Event]:
169
183
  return get_BD_events(
170
184
  reproduction_number=skyline_parameter(self.reproduction_number, data),
171
185
  infectious_period=skyline_parameter(self.infectious_period, data),
@@ -175,12 +189,12 @@ class BDTreeDatasetGenerator(TreeDatasetGenerator):
175
189
 
176
190
  class BDEITreeDatasetGenerator(TreeDatasetGenerator):
177
191
  parameterization: Literal[ParameterizationType.BDEI] = ParameterizationType.BDEI
178
- reproduction_number: cfg.SkylineParameterConfig
179
- infectious_period: cfg.SkylineParameterConfig
180
- incubation_period: cfg.SkylineParameterConfig
181
- sampling_proportion: cfg.SkylineParameterConfig = 1
192
+ reproduction_number: cfg.SkylineParameter
193
+ infectious_period: cfg.SkylineParameter
194
+ incubation_period: cfg.SkylineParameter
195
+ sampling_proportion: cfg.SkylineParameter = 1
182
196
 
183
- def _get_events(self, rng: Generator, data: dict[str, Any]) -> list[Event]:
197
+ def _get_events(self, data: dict[str, Any]) -> list[Event]:
184
198
  return get_BDEI_events(
185
199
  reproduction_number=skyline_parameter(self.reproduction_number, data),
186
200
  infectious_period=skyline_parameter(self.infectious_period, data),
@@ -191,13 +205,13 @@ class BDEITreeDatasetGenerator(TreeDatasetGenerator):
191
205
 
192
206
  class BDSSTreeDatasetGenerator(TreeDatasetGenerator):
193
207
  parameterization: Literal[ParameterizationType.BDSS] = ParameterizationType.BDSS
194
- reproduction_number: cfg.SkylineParameterConfig
195
- infectious_period: cfg.SkylineParameterConfig
196
- superspreading_ratio: cfg.SkylineParameterConfig
197
- superspreaders_proportion: cfg.SkylineParameterConfig
198
- sampling_proportion: cfg.SkylineParameterConfig = 1
208
+ reproduction_number: cfg.SkylineParameter
209
+ infectious_period: cfg.SkylineParameter
210
+ superspreading_ratio: cfg.SkylineParameter
211
+ superspreaders_proportion: cfg.SkylineParameter
212
+ sampling_proportion: cfg.SkylineParameter = 1
199
213
 
200
- def _get_events(self, rng: Generator, data: dict[str, Any]) -> list[Event]:
214
+ def _get_events(self, data: dict[str, Any]) -> list[Event]:
201
215
  return get_BDSS_events(
202
216
  reproduction_number=skyline_parameter(self.reproduction_number, data),
203
217
  infectious_period=skyline_parameter(self.infectious_period, data),
@@ -1,32 +1,28 @@
1
- from typing import TypeGuard
1
+ from typing import Any, TypeGuard
2
2
 
3
3
  import phylogenie.generators.configs as cfg
4
4
  import phylogenie.typings as pgt
5
5
 
6
6
 
7
- def is_list(x: object) -> TypeGuard[list[object]]:
7
+ def is_list(x: Any) -> TypeGuard[list[Any]]:
8
8
  return isinstance(x, list)
9
9
 
10
10
 
11
- def is_list_of_scalar_configs(x: object) -> TypeGuard[list[cfg.ScalarConfig]]:
12
- return is_list(x) and all(isinstance(v, cfg.ScalarConfig) for v in x)
11
+ def is_list_of_scalar_configs(x: Any) -> TypeGuard[list[cfg.Scalar]]:
12
+ return is_list(x) and all(isinstance(v, cfg.Scalar) for v in x)
13
13
 
14
14
 
15
15
  def is_list_of_skyline_parameter_configs(
16
- x: object,
17
- ) -> TypeGuard[list[cfg.SkylineParameterConfig]]:
18
- return is_list(x) and all(isinstance(v, cfg.SkylineParameterConfig) for v in x)
16
+ x: Any,
17
+ ) -> TypeGuard[list[cfg.SkylineParameter]]:
18
+ return is_list(x) and all(isinstance(v, cfg.SkylineParameter) for v in x)
19
19
 
20
20
 
21
- def is_skyline_vector_config(
22
- x: object,
23
- ) -> TypeGuard[cfg.SkylineVectorConfig]:
21
+ def is_skyline_vector_config(x: Any) -> TypeGuard[cfg.SkylineVector]:
24
22
  return isinstance(
25
23
  x, str | pgt.Scalar | cfg.SkylineVectorModel
26
24
  ) or is_list_of_skyline_parameter_configs(x)
27
25
 
28
26
 
29
- def is_list_of_skyline_vector_configs(
30
- x: object,
31
- ) -> TypeGuard[list[cfg.SkylineVectorConfig]]:
27
+ def is_list_of_skyline_vector_configs(x: Any) -> TypeGuard[list[cfg.SkylineVector]]:
32
28
  return is_list(x) and all(is_skyline_vector_config(v) for v in x)
@@ -1,5 +1,5 @@
1
1
  from collections.abc import Callable, Iterator
2
- from typing import TypeGuard, Union, overload
2
+ from typing import Any, TypeGuard, Union, overload
3
3
 
4
4
  import phylogenie.typeguards as tg
5
5
  import phylogenie.typings as pgt
@@ -20,7 +20,7 @@ SkylineMatrixOperand = Union[SkylineVectorOperand, "SkylineMatrix"]
20
20
  SkylineMatrixCoercible = Union[pgt.OneOrMany[SkylineVectorCoercible], "SkylineMatrix"]
21
21
 
22
22
 
23
- def is_skyline_matrix_operand(x: object) -> TypeGuard[SkylineMatrixOperand]:
23
+ def is_skyline_matrix_operand(x: Any) -> TypeGuard[SkylineMatrixOperand]:
24
24
  return isinstance(x, SkylineMatrix) or is_skyline_vector_operand(x)
25
25
 
26
26
 
@@ -142,7 +142,7 @@ class SkylineMatrix:
142
142
  def __bool__(self) -> bool:
143
143
  return any(self.params)
144
144
 
145
- def __eq__(self, other: object) -> bool:
145
+ def __eq__(self, other: Any) -> bool:
146
146
  return isinstance(other, SkylineMatrix) and self.params == other.params
147
147
 
148
148
  def __repr__(self) -> str:
@@ -1,6 +1,6 @@
1
1
  from bisect import bisect_right
2
2
  from collections.abc import Callable
3
- from typing import TypeGuard, Union
3
+ from typing import Any, TypeGuard, Union
4
4
 
5
5
  import phylogenie.typeguards as tg
6
6
  import phylogenie.typings as pgt
@@ -8,12 +8,12 @@ import phylogenie.typings as pgt
8
8
  SkylineParameterLike = Union[pgt.Scalar, "SkylineParameter"]
9
9
 
10
10
 
11
- def is_skyline_parameter_like(x: object) -> TypeGuard[SkylineParameterLike]:
11
+ def is_skyline_parameter_like(x: Any) -> TypeGuard[SkylineParameterLike]:
12
12
  return isinstance(x, pgt.Scalar | SkylineParameter)
13
13
 
14
14
 
15
15
  def is_many_skyline_parameters_like(
16
- x: object,
16
+ x: Any,
17
17
  ) -> TypeGuard[pgt.Many[SkylineParameterLike]]:
18
18
  return tg.is_many(x) and all(is_skyline_parameter_like(v) for v in x)
19
19
 
@@ -107,7 +107,7 @@ class SkylineParameter:
107
107
  def __bool__(self) -> bool:
108
108
  return any(self.value)
109
109
 
110
- def __eq__(self, other: object) -> bool:
110
+ def __eq__(self, other: Any) -> bool:
111
111
  return isinstance(other, SkylineParameter) and (
112
112
  self.value == other.value and self.change_times == other.change_times
113
113
  )
@@ -1,5 +1,5 @@
1
1
  from collections.abc import Callable, Iterator
2
- from typing import TypeGuard, Union, overload
2
+ from typing import Any, TypeGuard, Union, overload
3
3
 
4
4
  import phylogenie.typeguards as tg
5
5
  import phylogenie.typings as pgt
@@ -16,26 +16,24 @@ SkylineVectorLike = Union[pgt.Many[SkylineParameterLike], "SkylineVector"]
16
16
  SkylineVectorCoercible = Union[pgt.OneOrMany[SkylineParameterLike], "SkylineVector"]
17
17
 
18
18
 
19
- def is_skyline_vector_operand(x: object) -> TypeGuard[SkylineVectorOperand]:
19
+ def is_skyline_vector_operand(x: Any) -> TypeGuard[SkylineVectorOperand]:
20
20
  return isinstance(x, SkylineVector) or is_skyline_parameter_like(x)
21
21
 
22
22
 
23
- def is_skyline_vector_like(x: object) -> TypeGuard[SkylineVectorLike]:
23
+ def is_skyline_vector_like(x: Any) -> TypeGuard[SkylineVectorLike]:
24
24
  return isinstance(x, SkylineVector) or is_many_skyline_parameters_like(x)
25
25
 
26
26
 
27
- def is_skyline_vector_coercible(
28
- x: object,
29
- ) -> TypeGuard[SkylineVectorCoercible]:
27
+ def is_skyline_vector_coercible(x: Any) -> TypeGuard[SkylineVectorCoercible]:
30
28
  return is_skyline_parameter_like(x) or is_skyline_vector_like(x)
31
29
 
32
30
 
33
- def is_many_skyline_vectors_like(x: object) -> TypeGuard[pgt.Many[SkylineVectorLike]]:
31
+ def is_many_skyline_vectors_like(x: Any) -> TypeGuard[pgt.Many[SkylineVectorLike]]:
34
32
  return tg.is_many(x) and all(is_skyline_vector_like(v) for v in x)
35
33
 
36
34
 
37
35
  def is_many_skyline_vectors_coercible(
38
- x: object,
36
+ x: Any,
39
37
  ) -> TypeGuard[pgt.Many[SkylineVectorCoercible]]:
40
38
  return tg.is_many(x) and all(is_skyline_vector_coercible(v) for v in x)
41
39
 
@@ -131,7 +129,7 @@ class SkylineVector:
131
129
  def __bool__(self) -> bool:
132
130
  return any(self.params)
133
131
 
134
- def __eq__(self, other: object) -> bool:
132
+ def __eq__(self, other: Any) -> bool:
135
133
  return isinstance(other, SkylineVector) and self.params == other.params
136
134
 
137
135
  def __repr__(self) -> str:
phylogenie/tree.py CHANGED
@@ -1,6 +1,5 @@
1
- from __future__ import annotations
2
-
3
- from typing import Iterator
1
+ from collections.abc import Iterator
2
+ from typing import Any
4
3
 
5
4
 
6
5
  class Tree:
@@ -9,29 +8,30 @@ class Tree:
9
8
  self.branch_length = branch_length
10
9
  self.parent: Tree | None = None
11
10
  self.children: list[Tree] = []
11
+ self._features: dict[str, Any] = {}
12
12
 
13
- def add_child(self, child: Tree) -> Tree:
13
+ def add_child(self, child: "Tree") -> "Tree":
14
14
  child.parent = self
15
15
  self.children.append(child)
16
16
  return self
17
17
 
18
- def preorder_traversal(self) -> Iterator[Tree]:
18
+ def preorder_traversal(self) -> Iterator["Tree"]:
19
19
  yield self
20
20
  for child in self.children:
21
21
  yield from child.preorder_traversal()
22
22
 
23
- def postorder_traversal(self) -> Iterator[Tree]:
23
+ def postorder_traversal(self) -> Iterator["Tree"]:
24
24
  for child in self.children:
25
25
  yield from child.postorder_traversal()
26
26
  yield self
27
27
 
28
- def get_node(self, id: str) -> Tree:
28
+ def get_node(self, id: str) -> "Tree":
29
29
  for node in self:
30
30
  if node.id == id:
31
31
  return node
32
32
  raise ValueError(f"Node with id {id} not found.")
33
33
 
34
- def get_leaves(self) -> list[Tree]:
34
+ def get_leaves(self) -> list["Tree"]:
35
35
  return [node for node in self if not node.children]
36
36
 
37
37
  def get_time(self) -> float:
@@ -40,13 +40,22 @@ class Tree:
40
40
  raise ValueError(f"Branch length of node {self.id} is not set.")
41
41
  return self.branch_length + parent_time
42
42
 
43
- def copy(self) -> Tree:
43
+ def is_leaf(self) -> bool:
44
+ return not self.children
45
+
46
+ def copy(self) -> "Tree":
44
47
  new_tree = Tree(self.id, self.branch_length)
45
48
  for child in self.children:
46
49
  new_tree.add_child(child.copy())
47
50
  return new_tree
48
51
 
49
- def __iter__(self) -> Iterator[Tree]:
52
+ def get(self, key: str, default: Any = None) -> Any:
53
+ return self._features.get(key, default)
54
+
55
+ def set(self, key: str, value: Any) -> None:
56
+ self._features[key] = value
57
+
58
+ def __iter__(self) -> Iterator["Tree"]:
50
59
  return self.preorder_traversal()
51
60
 
52
61
  def __repr__(self) -> str:
@@ -4,6 +4,7 @@ from phylogenie.treesimulator.events import (
4
4
  get_BDEI_events,
5
5
  get_BDSS_events,
6
6
  get_canonical_events,
7
+ get_contact_tracing_events,
7
8
  get_epidemiological_events,
8
9
  get_FBD_events,
9
10
  )
@@ -15,6 +16,7 @@ __all__ = [
15
16
  "get_BDEI_events",
16
17
  "get_BDSS_events",
17
18
  "get_canonical_events",
19
+ "get_contact_tracing_events",
18
20
  "get_epidemiological_events",
19
21
  "get_FBD_events",
20
22
  "simulate_tree",
@@ -1,95 +1,84 @@
1
1
  from abc import ABC, abstractmethod
2
+ from collections.abc import Sequence
3
+ from dataclasses import dataclass
2
4
 
3
- from numpy.random import Generator
5
+ import numpy as np
4
6
 
5
7
  from phylogenie.skyline import (
6
8
  SkylineMatrixCoercible,
9
+ SkylineParameter,
7
10
  SkylineParameterLike,
8
11
  SkylineVectorCoercible,
9
12
  skyline_matrix,
10
13
  skyline_parameter,
11
14
  skyline_vector,
12
15
  )
13
- from phylogenie.treesimulator.model import Model
16
+ from phylogenie.treesimulator.model import Model, get_CT_state
14
17
 
15
18
  INFECTIOUS_STATE = "I"
16
19
  EXPOSED_STATE = "E"
17
20
  SUPERSPREADER_STATE = "S"
18
21
 
19
22
 
23
+ @dataclass
20
24
  class Event(ABC):
21
- def __init__(
22
- self,
23
- rate: SkylineParameterLike,
24
- state: str | None = None,
25
- ):
26
- self.rate = skyline_parameter(rate)
27
- self.state = state
25
+ rate: SkylineParameter
26
+ state: str
28
27
 
29
28
  def get_propensity(self, model: Model, time: float) -> float:
30
- return self.rate.get_value_at_time(time) * model.count_leaves(self.state)
29
+ n_individuals = model.count_individuals(self.state)
30
+ rate = self.rate.get_value_at_time(time)
31
+ if rate == np.inf and not n_individuals:
32
+ return 0
33
+ return rate * n_individuals
31
34
 
32
35
  @abstractmethod
33
- def apply(self, rng: Generator, model: Model, time: float) -> None: ...
36
+ def apply(self, model: Model, time: float) -> None: ...
34
37
 
35
38
 
39
+ @dataclass
36
40
  class BirthEvent(Event):
37
- def __init__(
38
- self,
39
- rate: SkylineParameterLike,
40
- state: str | None = None,
41
- child_state: str | None = None,
42
- ):
43
- super().__init__(rate, state)
44
- self.child_state = child_state
41
+ child_state: str
45
42
 
46
- def apply(self, rng: Generator, model: Model, time: float) -> None:
47
- node = model.get_random_leaf(self.state, rng)
48
- model.add_child(node, time, True, self.child_state)
43
+ def apply(self, model: Model, time: float) -> None:
44
+ individual = model.get_random_individual(self.state)
45
+ model.birth_from(individual, self.child_state, time)
49
46
 
50
47
 
51
48
  class DeathEvent(Event):
52
- def apply(self, rng: Generator, model: Model, time: float) -> None:
53
- node = model.get_random_leaf(self.state, rng)
54
- model.remove(node)
49
+ def apply(self, model: Model, time: float) -> None:
50
+ individual = model.get_random_individual(self.state)
51
+ model.remove(individual, time)
55
52
 
56
53
 
54
+ @dataclass
57
55
  class MigrationEvent(Event):
58
- def __init__(self, rate: SkylineParameterLike, state: str, target_state: str):
59
- super().__init__(rate, state)
60
- self.target_state = target_state
56
+ target_state: str
61
57
 
62
- def apply(self, rng: Generator, model: Model, time: float) -> None:
63
- node = model.get_random_leaf(self.state, rng)
64
- model.add_child(node, time, False, self.target_state)
58
+ def apply(self, model: Model, time: float) -> None:
59
+ individual = model.get_random_individual(self.state)
60
+ model.migrate(individual, self.target_state, time)
65
61
 
66
62
 
63
+ @dataclass
67
64
  class SamplingEvent(Event):
68
- def __init__(
69
- self,
70
- rate: SkylineParameterLike,
71
- state: str | None = None,
72
- removal_probability: SkylineParameterLike = 0,
73
- ):
74
- super().__init__(rate, state)
75
- self.removal_probability = skyline_parameter(removal_probability)
65
+ removal_probability: SkylineParameter
76
66
 
77
- def apply(self, rng: Generator, model: Model, time: float) -> None:
78
- node = model.get_random_leaf(self.state, rng)
79
- remove = rng.random() < self.removal_probability.get_value_at_time(time)
80
- model.sample(node, time, remove)
67
+ def apply(self, model: Model, time: float) -> None:
68
+ individual = model.get_random_individual(self.state)
69
+ model.sample(individual, time, self.removal_probability.get_value_at_time(time))
81
70
 
82
71
 
83
72
  def get_canonical_events(
73
+ states: Sequence[str],
84
74
  sampling_rates: SkylineVectorCoercible,
85
75
  birth_rates: SkylineVectorCoercible = 0,
86
76
  death_rates: SkylineVectorCoercible = 0,
87
77
  removal_probabilities: SkylineVectorCoercible = 0,
88
78
  migration_rates: SkylineMatrixCoercible | None = None,
89
79
  birth_rates_among_states: SkylineMatrixCoercible | None = None,
90
- states: list[str] | None = None,
91
80
  ) -> list[Event]:
92
- N = 1 if states is None else len(states)
81
+ N = len(states)
93
82
 
94
83
  birth_rates = skyline_vector(birth_rates, N)
95
84
  death_rates = skyline_vector(death_rates, N)
@@ -97,43 +86,38 @@ def get_canonical_events(
97
86
  removal_probabilities = skyline_vector(removal_probabilities, N)
98
87
 
99
88
  events: list[Event] = []
100
- for i in range(N):
101
- state = None if states is None else states[i]
89
+ for i, state in enumerate(states):
102
90
  events.append(BirthEvent(birth_rates[i], state, state))
103
91
  events.append(DeathEvent(death_rates[i], state))
104
92
  events.append(SamplingEvent(sampling_rates[i], state, removal_probabilities[i]))
105
93
 
106
- if states is not None and migration_rates is not None:
94
+ if migration_rates is not None:
107
95
  migration_rates = skyline_matrix(migration_rates, N, N - 1)
108
96
  for i, state in enumerate(states):
109
97
  for j, other_state in enumerate([s for s in states if s != state]):
110
98
  events.append(MigrationEvent(migration_rates[i, j], state, other_state))
111
- elif migration_rates is not None:
112
- raise ValueError(f"Migration rates require states to be provided.")
113
99
 
114
- if states is not None and birth_rates_among_states is not None:
100
+ if birth_rates_among_states is not None:
115
101
  birth_rates_among_states = skyline_matrix(birth_rates_among_states, N, N - 1)
116
102
  for i, state in enumerate(states):
117
103
  for j, other_state in enumerate([s for s in states if s != state]):
118
104
  events.append(
119
105
  BirthEvent(birth_rates_among_states[i, j], state, other_state)
120
106
  )
121
- elif birth_rates_among_states is not None:
122
- raise ValueError(f"Birth rates among states require states to be provided.")
123
107
 
124
108
  return [event for event in events if event.rate]
125
109
 
126
110
 
127
111
  def get_epidemiological_events(
112
+ states: Sequence[str],
128
113
  sampling_proportions: SkylineVectorCoercible = 1,
129
114
  reproduction_numbers: SkylineVectorCoercible = 0,
130
115
  become_uninfectious_rates: SkylineVectorCoercible = 0,
131
116
  removal_probabilities: SkylineVectorCoercible = 1,
132
117
  migration_rates: SkylineMatrixCoercible | None = None,
133
118
  reproduction_numbers_among_states: SkylineMatrixCoercible | None = None,
134
- states: list[str] | None = None,
135
119
  ) -> list[Event]:
136
- N = 1 if states is None else len(states)
120
+ N = len(states)
137
121
 
138
122
  reproduction_numbers = skyline_vector(reproduction_numbers, N)
139
123
  become_uninfectious_rates = skyline_vector(become_uninfectious_rates, N)
@@ -143,16 +127,14 @@ def get_epidemiological_events(
143
127
  birth_rates = reproduction_numbers * become_uninfectious_rates
144
128
  sampling_rates = become_uninfectious_rates * sampling_proportions
145
129
  death_rates = become_uninfectious_rates - removal_probabilities * sampling_rates
146
- birth_rates_among_states = None
147
- if states is None and reproduction_numbers_among_states is not None:
148
- raise ValueError(
149
- f"Reproduction numbers among states require states to be provided."
150
- )
151
- elif reproduction_numbers_among_states is not None:
152
- birth_rates_among_states = (
130
+ birth_rates_among_states = (
131
+ (
153
132
  skyline_matrix(reproduction_numbers_among_states, N, N - 1)
154
133
  * become_uninfectious_rates
155
134
  )
135
+ if reproduction_numbers_among_states is not None
136
+ else None
137
+ )
156
138
 
157
139
  return get_canonical_events(
158
140
  states=states,
@@ -166,15 +148,15 @@ def get_epidemiological_events(
166
148
 
167
149
 
168
150
  def get_FBD_events(
151
+ states: Sequence[str],
169
152
  diversification: SkylineVectorCoercible = 0,
170
153
  turnover: SkylineVectorCoercible = 0,
171
154
  sampling_proportions: SkylineVectorCoercible = 1,
172
155
  removal_probabilities: SkylineVectorCoercible = 0,
173
156
  migration_rates: SkylineMatrixCoercible | None = None,
174
157
  diversification_between_types: SkylineMatrixCoercible | None = None,
175
- states: list[str] | None = None,
176
158
  ) -> list[Event]:
177
- N = 1 if states is None else len(states)
159
+ N = len(states)
178
160
 
179
161
  diversification = skyline_vector(diversification, N)
180
162
  turnover = skyline_vector(turnover, N)
@@ -185,15 +167,11 @@ def get_FBD_events(
185
167
  death_rates = turnover * birth_rates
186
168
  sampling_rates_dividend = 1 - removal_probabilities * sampling_proportions
187
169
  sampling_rates = sampling_proportions * death_rates / sampling_rates_dividend
188
- birth_rates_among_states = None
189
- if states is None and diversification_between_types is not None:
190
- raise ValueError(
191
- f"Diversification rates among states require states to be provided."
192
- )
193
- elif diversification_between_types is not None:
194
- birth_rates_among_states = (
195
- skyline_matrix(diversification_between_types, N, N - 1) + death_rates
196
- )
170
+ birth_rates_among_states = (
171
+ (skyline_matrix(diversification_between_types, N, N - 1) + death_rates)
172
+ if diversification_between_types is not None
173
+ else None
174
+ )
197
175
 
198
176
  return get_canonical_events(
199
177
  states=states,
@@ -212,6 +190,7 @@ def get_BD_events(
212
190
  sampling_proportion: SkylineParameterLike = 1,
213
191
  ) -> list[Event]:
214
192
  return get_epidemiological_events(
193
+ states=[INFECTIOUS_STATE],
215
194
  reproduction_numbers=reproduction_number,
216
195
  become_uninfectious_rates=1 / infectious_period,
217
196
  sampling_proportions=sampling_proportion,
@@ -253,3 +232,40 @@ def get_BDSS_events(
253
232
  become_uninfectious_rates=1 / infectious_period,
254
233
  sampling_proportions=sampling_proportion,
255
234
  )
235
+
236
+
237
+ def get_contact_tracing_events(
238
+ events: Sequence[Event],
239
+ samplable_states_after_notification: Sequence[str] | None = None,
240
+ sampling_rate_after_notification: SkylineParameterLike = np.inf,
241
+ contacts_removal_probability: SkylineParameterLike = 1,
242
+ ) -> list[Event]:
243
+ ct_events = list(events)
244
+ for event in events:
245
+ if isinstance(event, MigrationEvent):
246
+ ct_events.append(
247
+ MigrationEvent(
248
+ event.rate,
249
+ get_CT_state(event.state),
250
+ get_CT_state(event.target_state),
251
+ )
252
+ )
253
+ elif isinstance(event, BirthEvent):
254
+ ct_events.append(
255
+ BirthEvent(event.rate, get_CT_state(event.state), event.child_state)
256
+ )
257
+
258
+ for state in (
259
+ samplable_states_after_notification
260
+ if samplable_states_after_notification is not None
261
+ else [e.state for e in events]
262
+ ):
263
+ ct_events.append(
264
+ SamplingEvent(
265
+ skyline_parameter(sampling_rate_after_notification),
266
+ get_CT_state(state),
267
+ skyline_parameter(contacts_removal_probability),
268
+ )
269
+ )
270
+
271
+ return ct_events
@@ -3,9 +3,10 @@ from collections.abc import Sequence
3
3
  import numpy as np
4
4
  from numpy.random import default_rng
5
5
 
6
+ from phylogenie.skyline import SkylineParameterLike
6
7
  from phylogenie.tree import Tree
7
- from phylogenie.treesimulator.events import Event
8
- from phylogenie.treesimulator.model import Model
8
+ from phylogenie.treesimulator.events import Event, get_contact_tracing_events
9
+ from phylogenie.treesimulator.model import Model, is_CT_state
9
10
 
10
11
 
11
12
  def simulate_tree(
@@ -15,6 +16,11 @@ def simulate_tree(
15
16
  max_time: float = np.inf,
16
17
  init_state: str | None = None,
17
18
  sampling_probability_at_present: float = 0.0,
19
+ notification_probability: float = 0,
20
+ max_notified_contacts: int = 1,
21
+ samplable_states_after_notification: Sequence[str] | None = None,
22
+ sampling_rate_after_notification: SkylineParameterLike = np.inf,
23
+ contacts_removal_probability: SkylineParameterLike = 1,
18
24
  max_tries: int | None = None,
19
25
  seed: int | None = None,
20
26
  ) -> Tree | None:
@@ -23,15 +29,19 @@ def simulate_tree(
23
29
  if max_tips is None and max_time == np.inf:
24
30
  raise ValueError("Either max_tips or max_time must be specified.")
25
31
 
32
+ if notification_probability:
33
+ events = get_contact_tracing_events(
34
+ events,
35
+ samplable_states_after_notification,
36
+ sampling_rate_after_notification,
37
+ contacts_removal_probability,
38
+ )
39
+
26
40
  n_tries = 0
27
- states = [e.state for e in events if e.state is not None]
28
- init_state = (
29
- init_state
30
- if init_state is not None
31
- else str(rng.choice(states)) if states else None
32
- )
41
+ root_states = [e.state for e in events if not is_CT_state(e.state)]
33
42
  while max_tries is None or n_tries < max_tries:
34
- model = Model(init_state)
43
+ root_state = init_state if init_state is not None else rng.choice(root_states)
44
+ model = Model(root_state, max_notified_contacts, notification_probability, rng)
35
45
  current_time = 0.0
36
46
  change_times = sorted(set(t for e in events for t in e.rate.change_times))
37
47
  next_change_time = change_times.pop(0) if change_times else np.inf
@@ -39,6 +49,13 @@ def simulate_tree(
39
49
 
40
50
  while current_time < max_time and (n_tips is None or model.n_sampled < n_tips):
41
51
  rates = [e.get_propensity(model, current_time) for e in events]
52
+
53
+ instantaneous_events = [e for e, r in zip(events, rates) if r == np.inf]
54
+ if instantaneous_events:
55
+ event = instantaneous_events[rng.integers(len(instantaneous_events))]
56
+ event.apply(model, current_time)
57
+ continue
58
+
42
59
  if not any(rates):
43
60
  break
44
61
 
@@ -52,11 +69,11 @@ def simulate_tree(
52
69
  continue
53
70
 
54
71
  event_idx = np.searchsorted(np.cumsum(rates) / sum(rates), rng.random())
55
- events[int(event_idx)].apply(rng, model, current_time)
72
+ events[int(event_idx)].apply(model, current_time)
56
73
 
57
- for leaf in model.get_leaves():
74
+ for individual in model.get_population():
58
75
  if rng.random() < sampling_probability_at_present:
59
- model.sample(leaf, current_time, True)
76
+ model.sample(individual, current_time, 1)
60
77
 
61
78
  if model.n_sampled >= min_tips and (
62
79
  max_tips is None or model.n_sampled <= max_tips
@@ -1,71 +1,125 @@
1
1
  from collections import defaultdict
2
+ from dataclasses import dataclass, field
3
+ from typing import ClassVar
2
4
 
3
5
  from numpy.random import Generator, default_rng
4
6
 
5
7
  from phylogenie.tree import Tree
6
8
 
9
+ CT_POSTFIX = "-CT"
7
10
 
8
- class Model:
9
- def __init__(self, init_state: str | None = None):
10
- self._next_id = 0
11
- self._n_sampled = 0
12
- self._leaves: dict[str, Tree] = {}
13
- self._leaf2state: dict[str, str | None] = {}
14
- self._state2leaves: dict[str | None, set[str]] = defaultdict(set)
15
- self._tree = self._get_new_node(init_state, None)
16
11
 
17
- @property
18
- def next_id(self) -> int:
19
- self._next_id += 1
20
- return self._next_id
12
+ def get_CT_state(state: str) -> str:
13
+ return f"{state}{CT_POSTFIX}"
14
+
15
+
16
+ def is_CT_state(state: str) -> bool:
17
+ return state.endswith(CT_POSTFIX)
18
+
19
+
20
+ @dataclass
21
+ class Individual:
22
+ node: Tree
23
+ state: str
24
+ id: int = field(init=False)
25
+ _id_counter: ClassVar[int] = 0
26
+
27
+ def __post_init__(self):
28
+ Individual._id_counter += 1
29
+ self.id = Individual._id_counter
30
+
31
+
32
+ class Model:
33
+ def __init__(
34
+ self,
35
+ init_state: str,
36
+ max_notified_contacts: int = 1,
37
+ notification_probability: float = 0,
38
+ rng: int | Generator | None = None,
39
+ ):
40
+ self._next_node_id = 0
41
+ self._population: dict[int, Individual] = {}
42
+ self._states: dict[str, set[int]] = defaultdict(set)
43
+ self._contacts: dict[int, list[Individual]] = defaultdict(list)
44
+ self._sampled: set[str] = set()
45
+ self._tree = self._get_new_individual(init_state).node
46
+ self._max_notified_contacts = max_notified_contacts
47
+ self._notification_probability = notification_probability
48
+ self._rng = rng if isinstance(rng, Generator) else default_rng(rng)
21
49
 
22
50
  @property
23
51
  def n_sampled(self) -> int:
24
- return self._n_sampled
25
-
26
- def _get_new_node(self, state: str | None, branch_length: float | None) -> Tree:
27
- id = str(self.next_id) if state is None else f"{self.next_id}|{state}"
28
- node = Tree(id, branch_length)
29
- if branch_length is None:
30
- self._leaves[id] = node
31
- self._leaf2state[id] = state
32
- self._state2leaves[state].add(id)
33
- return node
34
-
35
- def remove(self, node_id: str) -> None:
36
- self._state2leaves[self._leaf2state[node_id]].remove(node_id)
37
- self._leaf2state.pop(node_id, None)
38
- self._leaves.pop(node_id)
39
-
40
- def add_child(
41
- self,
42
- node_id: str,
43
- time: float,
44
- stem: bool,
45
- state: str | None,
46
- branch_length: float | None = None,
47
- ) -> None:
48
- node = self._leaves[node_id]
52
+ return len(self._sampled)
53
+
54
+ def _get_new_node(self, state: str) -> Tree:
55
+ self._next_node_id += 1
56
+ return Tree(f"{self._next_node_id}|{state}")
57
+
58
+ def _get_new_individual(self, state: str) -> Individual:
59
+ individual = Individual(self._get_new_node(state), state)
60
+ self._population[individual.id] = individual
61
+ self._states[state].add(individual.id)
62
+ return individual
63
+
64
+ def _set_branch_length(self, node: Tree, time: float) -> None:
49
65
  if node.branch_length is not None:
50
- raise ValueError("Cannot add a child to a node with a set branch length.")
51
- node.add_child(self._get_new_node(state, branch_length))
52
- if stem:
53
- node.add_child(self._get_new_node(self._leaf2state[node.id], None))
66
+ raise ValueError(f"Branch length of node {node.id} is already set.")
54
67
  node.branch_length = (
55
68
  time if node.parent is None else time - node.parent.get_time()
56
69
  )
57
- self.remove(node_id)
58
70
 
59
- def sample(self, node_id: str, time: float, remove: bool) -> None:
60
- self.add_child(node_id, time, not remove, self._leaf2state[node_id], 0.0)
61
- self._n_sampled += 1
71
+ def _stem(self, individual: Individual, time: float) -> None:
72
+ self._set_branch_length(individual.node, time)
73
+ stem_node = self._get_new_node(individual.state)
74
+ individual.node.add_child(stem_node)
75
+ individual.node = stem_node
76
+
77
+ def remove(self, id: int, time: float) -> None:
78
+ individual = self._population[id]
79
+ self._set_branch_length(individual.node, time)
80
+ state = individual.state
81
+ self._population.pop(id)
82
+ self._states[state].remove(id)
83
+
84
+ def migrate(self, id: int, state: str, time: float) -> None:
85
+ individual = self._population[id]
86
+ self._states[individual.state].remove(id)
87
+ individual.state = state
88
+ self._states[state].add(id)
89
+ self._stem(individual, time)
90
+
91
+ def birth_from(self, id: int, state: str, time: float) -> None:
92
+ individual = self._population[id]
93
+ new_individual = self._get_new_individual(state)
94
+ individual.node.add_child(new_individual.node)
95
+ self._stem(individual, time)
96
+ self._contacts[id].append(new_individual)
97
+ self._contacts[new_individual.id].append(individual)
98
+
99
+ def sample(self, id: int, time: float, removal_probability: float) -> None:
100
+ individual = self._population[id]
101
+ if self._rng.random() < removal_probability:
102
+ self._sampled.add(individual.node.id)
103
+ self.remove(id, time)
104
+ else:
105
+ sample_node = self._get_new_node(individual.state)
106
+ sample_node.branch_length = 0.0
107
+ self._sampled.add(sample_node.id)
108
+ individual.node.add_child(sample_node)
109
+ self._stem(individual, time)
110
+
111
+ for contact in self._contacts[id][-self._max_notified_contacts :]:
112
+ if (
113
+ contact.id in self._population
114
+ and not is_CT_state(contact.state)
115
+ and self._rng.random() < self._notification_probability
116
+ ):
117
+ self.migrate(contact.id, get_CT_state(contact.state), time)
62
118
 
63
119
  def get_sampled_tree(self) -> Tree:
64
120
  tree = self._tree.copy()
65
121
  for node in list(tree.postorder_traversal()):
66
- if node.branch_length is None or (
67
- node.branch_length > 0 and not node.children
68
- ):
122
+ if node.id not in self._sampled and not node.children:
69
123
  if node.parent is None:
70
124
  raise ValueError("No samples in the tree.")
71
125
  else:
@@ -83,18 +137,18 @@ class Model:
83
137
  node.parent.children.remove(node)
84
138
  return tree
85
139
 
86
- def get_random_leaf(
87
- self, state: str | None = None, rng: int | Generator | None = None
88
- ) -> str:
89
- rng = rng if isinstance(rng, Generator) else default_rng(rng)
140
+ def get_full_tree(self) -> Tree:
141
+ return self._tree.copy()
142
+
143
+ def get_random_individual(self, state: str | None = None) -> int:
90
144
  if state is None:
91
- return rng.choice(list(self._leaves))
92
- return rng.choice(list(self._state2leaves[state]))
145
+ return self._rng.choice(list(self._population))
146
+ return self._rng.choice(list(self._states[state]))
93
147
 
94
- def get_leaves(self) -> list[str]:
95
- return list(self._leaves)
148
+ def get_population(self) -> list[int]:
149
+ return list(self._population)
96
150
 
97
- def count_leaves(self, state: str | None = None) -> int:
151
+ def count_individuals(self, state: str | None = None) -> int:
98
152
  if state is None:
99
- return len(self._leaves)
100
- return len(self._state2leaves[state])
153
+ return len(self._population)
154
+ return len(self._states[state])
phylogenie/typeguards.py CHANGED
@@ -1,42 +1,42 @@
1
1
  from collections.abc import Sequence
2
- from typing import TypeGuard
2
+ from typing import Any, TypeGuard
3
3
 
4
4
  import phylogenie.typings as pgt
5
5
 
6
6
 
7
- def is_many(x: object) -> TypeGuard[pgt.Many[object]]:
7
+ def is_many(x: Any) -> TypeGuard[pgt.Many[Any]]:
8
8
  return isinstance(x, Sequence) and not isinstance(x, str)
9
9
 
10
10
 
11
- def is_many_scalars(x: object) -> TypeGuard[pgt.Many[pgt.Scalar]]:
11
+ def is_many_scalars(x: Any) -> TypeGuard[pgt.Many[pgt.Scalar]]:
12
12
  return is_many(x) and all(isinstance(i, pgt.Scalar) for i in x)
13
13
 
14
14
 
15
- def is_many_ints(x: object) -> TypeGuard[pgt.Many[int]]:
15
+ def is_many_ints(x: Any) -> TypeGuard[pgt.Many[int]]:
16
16
  return is_many(x) and all(isinstance(i, int) for i in x)
17
17
 
18
18
 
19
- def is_one_or_many_scalars(x: object) -> TypeGuard[pgt.OneOrManyScalars]:
19
+ def is_one_or_many_scalars(x: Any) -> TypeGuard[pgt.OneOrManyScalars]:
20
20
  return isinstance(x, pgt.Scalar) or is_many_scalars(x)
21
21
 
22
22
 
23
- def is_many_one_or_many_scalars(x: object) -> TypeGuard[pgt.Many[pgt.OneOrManyScalars]]:
23
+ def is_many_one_or_many_scalars(x: Any) -> TypeGuard[pgt.Many[pgt.OneOrManyScalars]]:
24
24
  return is_many(x) and all(is_one_or_many_scalars(i) for i in x)
25
25
 
26
26
 
27
- def is_many_2D_scalars(x: object) -> TypeGuard[pgt.Many2DScalars]:
27
+ def is_many_2D_scalars(x: Any) -> TypeGuard[pgt.Many2DScalars]:
28
28
  return is_many(x) and all(is_many_scalars(i) for i in x)
29
29
 
30
30
 
31
- def is_one_or_many_2D_scalars(x: object) -> TypeGuard[pgt.OneOrMany2DScalars]:
31
+ def is_one_or_many_2D_scalars(x: Any) -> TypeGuard[pgt.OneOrMany2DScalars]:
32
32
  return isinstance(x, pgt.Scalar) or is_many_2D_scalars(x)
33
33
 
34
34
 
35
35
  def is_many_one_or_many_2D_scalars(
36
- x: object,
36
+ x: Any,
37
37
  ) -> TypeGuard[pgt.Many[pgt.OneOrMany2DScalars]]:
38
38
  return is_many(x) and all(is_one_or_many_2D_scalars(i) for i in x)
39
39
 
40
40
 
41
- def is_many_3D_scalars(x: object) -> TypeGuard[pgt.Many3DScalars]:
41
+ def is_many_3D_scalars(x: Any) -> TypeGuard[pgt.Many3DScalars]:
42
42
  return is_many(x) and all(is_many_2D_scalars(i) for i in x)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: phylogenie
3
- Version: 2.0.5
3
+ Version: 2.0.7
4
4
  Summary: Generate phylogenetic datasets with minimal setup effort
5
5
  Author: Gabriele Marino
6
6
  Author-email: gabmarino.8601@gmail.com
@@ -45,7 +45,7 @@ Phylogenie comes packed with useful features, including:
45
45
  Simply specify the number of cores to use, and Phylogenie handles multiprocessing automatically.
46
46
 
47
47
  - **Pre-implemented parameterizations** 🎯
48
- Include canonical, fossilized birth-death, epidemiological, birth-death with exposed-infectious (BDEI), birth-death with superspreading (BDSS), and more.
48
+ Include canonical, fossilized birth-death, epidemiological, birth-death with exposed-infectious (BDEI), birth-death with superspreading (BDSS), and contact tracing (CT).
49
49
 
50
50
  - **Skyline parameter support** 🪜
51
51
  Support for piece-wise constant parameters.
@@ -0,0 +1,28 @@
1
+ phylogenie/__init__.py,sha256=1w_0H9lg7hI3b-NLjKuzc34GbJJGyjLq9LrlogecTzI,1759
2
+ phylogenie/generators/__init__.py,sha256=zsOxy28-9j9alOQLIgrOAFfmM58NNHO_NEtW-KXQXAY,888
3
+ phylogenie/generators/alisim.py,sha256=dDqlSwLDbRE2u5SZlsq1mArobTBtuk0aeXY3m1N-bWA,2374
4
+ phylogenie/generators/configs.py,sha256=5ZWdKhRUjlNifw7QKXbooKV1fElqfCk_jBGxfcjh8do,969
5
+ phylogenie/generators/dataset.py,sha256=k6RYJpgxOL8a_yMq98WUF-dcJv8TwxaWnde0k13M4J0,2525
6
+ phylogenie/generators/factories.py,sha256=O8wqL-PvZps0Dq6mQa_PTi4vBvky5LkQIy1jjfOUm-4,6944
7
+ phylogenie/generators/trees.py,sha256=jukaVXGcPGzDBEYMGJ1MKqWt4XbAB5EEfuHXDpwKTqM,9173
8
+ phylogenie/generators/typeguards.py,sha256=Qph6ZnQ7wDMUNvB0VWQKlq42f8wkKOnM42cfMqhNov4,862
9
+ phylogenie/io.py,sha256=ZXlofnSh7FX5UJiP0svRHrTraMSNgKa1GiAv0bMz7jU,2854
10
+ phylogenie/main.py,sha256=vtvSpQxBNlYABoFQ25czl-l3fIr4QRo3svWVd-jcArw,1170
11
+ phylogenie/msa.py,sha256=JDGyZUsAq6-m-SQjoCDjAkAZIxfgyl_PDIhdYn5HOow,2064
12
+ phylogenie/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
+ phylogenie/skyline/__init__.py,sha256=7pF4CUb4ZCLzNYJNhOjpuTOLTRhlK7L6ugfccNqjIGo,620
14
+ phylogenie/skyline/matrix.py,sha256=Gl8OgKjtieG0NwPYiPimKI36gefV8fm_OeorjdXxPTs,9146
15
+ phylogenie/skyline/parameter.py,sha256=EM9qlPt0JhMBy3TbztM0dj24BaGNEy8KWKdTObDKhbI,4644
16
+ phylogenie/skyline/vector.py,sha256=bJP7_FNX_Klt6wXqsyfj0KX3VNj6-dIhzCKSJuQcOV0,7115
17
+ phylogenie/tree.py,sha256=34gcxUoTGfj72EbIlpnrhWGnNFppUVjms3XEn1ZS3-g,1997
18
+ phylogenie/treesimulator/__init__.py,sha256=INPU9LrPdUmt3dYGzWDRoRKrPR9xENcHu44pJVUbyNA,525
19
+ phylogenie/treesimulator/events.py,sha256=X3_0U9qqMpYgh6-7TwQEnlUipANkHz6QTCXlm-qXFQk,9524
20
+ phylogenie/treesimulator/gillespie.py,sha256=Fn-PyVICx3pWtpHko7rf6omf_kqOkkpebSJy56oPKnQ,3216
21
+ phylogenie/treesimulator/model.py,sha256=XpzAicmg2O6K0Trk5YolH-B_HJZxoSauF2wZOMqp-Iw,5559
22
+ phylogenie/typeguards.py,sha256=JtqmbEWJZBRHbWgCvcl6nrWm3VcBfzRbklbTBYHItn0,1325
23
+ phylogenie/typings.py,sha256=O1X6lGKTjJ2YJz3ApQ-rYb_tEJNUIcHdUIeYlSM4s5o,500
24
+ phylogenie-2.0.7.dist-info/LICENSE.txt,sha256=NUrDqElK-eD3I0WqC004CJsy6cs0JgsAoebDv_42-pw,1071
25
+ phylogenie-2.0.7.dist-info/METADATA,sha256=REo_HAAqLn0XQiHeCPyNPbGuSsbT5p1QupjSxg5Zs_U,5472
26
+ phylogenie-2.0.7.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
27
+ phylogenie-2.0.7.dist-info/entry_points.txt,sha256=Rt6_usN0FkBX1ZfiqCirjMN9FKOgFLG8rydcQ8kugeE,51
28
+ phylogenie-2.0.7.dist-info/RECORD,,
@@ -1,28 +0,0 @@
1
- phylogenie/__init__.py,sha256=1w_0H9lg7hI3b-NLjKuzc34GbJJGyjLq9LrlogecTzI,1759
2
- phylogenie/generators/__init__.py,sha256=zsOxy28-9j9alOQLIgrOAFfmM58NNHO_NEtW-KXQXAY,888
3
- phylogenie/generators/alisim.py,sha256=Psh_fAHHrJBCd1kqVtz_BdoVVmNvieHJRMmapQncJhM,2288
4
- phylogenie/generators/configs.py,sha256=uiqPvdGSwmItQmhRMT-s1xG7DVnfpwaDTZ_wztL1Pwo,1123
5
- phylogenie/generators/dataset.py,sha256=zJUSC_c1rNfqq6zTkgGQVPmHcSW47HvfQfDdVTamiT8,2557
6
- phylogenie/generators/factories.py,sha256=lJ517JC0T8fVu_2NBNC1GUbMPQfRNeNpO5rGdxubQOA,6998
7
- phylogenie/generators/trees.py,sha256=pSqnYslDMPl4OVsIUFgQu_Rtc2htlBVwhIWnYgSAr3s,8735
8
- phylogenie/generators/typeguards.py,sha256=6WIQIy6huHok4vYqA9ym6g2kvyyXPggBONWwbiHtDiY,925
9
- phylogenie/io.py,sha256=ZXlofnSh7FX5UJiP0svRHrTraMSNgKa1GiAv0bMz7jU,2854
10
- phylogenie/main.py,sha256=vtvSpQxBNlYABoFQ25czl-l3fIr4QRo3svWVd-jcArw,1170
11
- phylogenie/msa.py,sha256=JDGyZUsAq6-m-SQjoCDjAkAZIxfgyl_PDIhdYn5HOow,2064
12
- phylogenie/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
- phylogenie/skyline/__init__.py,sha256=7pF4CUb4ZCLzNYJNhOjpuTOLTRhlK7L6ugfccNqjIGo,620
14
- phylogenie/skyline/matrix.py,sha256=BnIreEyCymWzo5jiFtLWXE33ck0pbOFNAfcJW7S2m2s,9147
15
- phylogenie/skyline/parameter.py,sha256=CJ5OEyRQG2Tg1WJWQ1IpfX-6hjJv80Zj8lMoRke5nnQ,4648
16
- phylogenie/skyline/vector.py,sha256=becZlWLeT12mvpDC5MLGXZPKlgmbkY-DEcuhdQpnsSc,7135
17
- phylogenie/tree.py,sha256=nNUQrTtUqgc-Z6CHflqROZ8W2ZL1YWGG7xeWmcKwZUU,1684
18
- phylogenie/treesimulator/__init__.py,sha256=koyVZxKUeMtY1S2WUgqbeIthoGbdUmx9vP4svFi4a6U,459
19
- phylogenie/treesimulator/events.py,sha256=RfFodTxMAPUtI3ZkDmw5ekXCaRGnQ8Sy3tjR4oJ8CW4,9636
20
- phylogenie/treesimulator/gillespie.py,sha256=p2bjwc6iyaf4mMI-xregJwOnXkjT2tIhxXbcGncQ5b0,2293
21
- phylogenie/treesimulator/model.py,sha256=pPl4JYnBIHjWeSgpIypaDOes11tA4bFAGuMRArXbsVo,3577
22
- phylogenie/typeguards.py,sha256=WBOSJSaOC8VDtrYoA2w_AYEXTpyKdCfmsM29KaKXl3A,1350
23
- phylogenie/typings.py,sha256=O1X6lGKTjJ2YJz3ApQ-rYb_tEJNUIcHdUIeYlSM4s5o,500
24
- phylogenie-2.0.5.dist-info/LICENSE.txt,sha256=NUrDqElK-eD3I0WqC004CJsy6cs0JgsAoebDv_42-pw,1071
25
- phylogenie-2.0.5.dist-info/METADATA,sha256=PHkxj9W4YYScWLa9oNke4sR_dXlT2-W8f6Z4ekuCgnI,5456
26
- phylogenie-2.0.5.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
27
- phylogenie-2.0.5.dist-info/entry_points.txt,sha256=Rt6_usN0FkBX1ZfiqCirjMN9FKOgFLG8rydcQ8kugeE,51
28
- phylogenie-2.0.5.dist-info/RECORD,,