phylogenie 2.0.13__py3-none-any.whl → 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- phylogenie/__init__.py +2 -6
- phylogenie/generators/__init__.py +0 -8
- phylogenie/generators/dataset.py +8 -2
- phylogenie/generators/factories.py +10 -10
- phylogenie/generators/trees.py +65 -59
- phylogenie/io.py +5 -2
- phylogenie/treesimulator/__init__.py +2 -1
- phylogenie/treesimulator/events.py +110 -72
- phylogenie/treesimulator/gillespie.py +80 -36
- phylogenie/treesimulator/model.py +21 -49
- {phylogenie-2.0.13.dist-info → phylogenie-2.1.0.dist-info}/METADATA +1 -1
- phylogenie-2.1.0.dist-info/RECORD +28 -0
- phylogenie-2.0.13.dist-info/RECORD +0 -28
- {phylogenie-2.0.13.dist-info → phylogenie-2.1.0.dist-info}/LICENSE.txt +0 -0
- {phylogenie-2.0.13.dist-info → phylogenie-2.1.0.dist-info}/WHEEL +0 -0
- {phylogenie-2.0.13.dist-info → phylogenie-2.1.0.dist-info}/entry_points.txt +0 -0
phylogenie/__init__.py
CHANGED
|
@@ -8,9 +8,6 @@ from phylogenie.generators import (
|
|
|
8
8
|
DatasetGeneratorConfig,
|
|
9
9
|
EpidemiologicalTreeDatasetGenerator,
|
|
10
10
|
FBDTreeDatasetGenerator,
|
|
11
|
-
SkylineMatrixModel,
|
|
12
|
-
SkylineParameterModel,
|
|
13
|
-
SkylineVectorModel,
|
|
14
11
|
TreeDatasetGeneratorConfig,
|
|
15
12
|
)
|
|
16
13
|
from phylogenie.io import load_fasta, load_newick
|
|
@@ -30,6 +27,7 @@ from phylogenie.skyline import (
|
|
|
30
27
|
from phylogenie.tree import Tree
|
|
31
28
|
from phylogenie.treesimulator import (
|
|
32
29
|
Event,
|
|
30
|
+
generate_trees,
|
|
33
31
|
get_BD_events,
|
|
34
32
|
get_BDEI_events,
|
|
35
33
|
get_BDSS_events,
|
|
@@ -51,16 +49,13 @@ __all__ = [
|
|
|
51
49
|
"FBDTreeDatasetGenerator",
|
|
52
50
|
"SkylineMatrix",
|
|
53
51
|
"SkylineMatrixCoercible",
|
|
54
|
-
"SkylineMatrixModel",
|
|
55
52
|
"skyline_matrix",
|
|
56
53
|
"SkylineParameter",
|
|
57
54
|
"SkylineParameterLike",
|
|
58
|
-
"SkylineParameterModel",
|
|
59
55
|
"skyline_parameter",
|
|
60
56
|
"SkylineVector",
|
|
61
57
|
"SkylineVectorCoercible",
|
|
62
58
|
"SkylineVectorLike",
|
|
63
|
-
"SkylineVectorModel",
|
|
64
59
|
"skyline_vector",
|
|
65
60
|
"Tree",
|
|
66
61
|
"TreeDatasetGeneratorConfig",
|
|
@@ -71,6 +66,7 @@ __all__ = [
|
|
|
71
66
|
"get_canonical_events",
|
|
72
67
|
"get_epidemiological_events",
|
|
73
68
|
"get_FBD_events",
|
|
69
|
+
"generate_trees",
|
|
74
70
|
"simulate_tree",
|
|
75
71
|
"load_fasta",
|
|
76
72
|
"load_newick",
|
|
@@ -3,11 +3,6 @@ from typing import Annotated
|
|
|
3
3
|
from pydantic import Field
|
|
4
4
|
|
|
5
5
|
from phylogenie.generators.alisim import AliSimDatasetGenerator
|
|
6
|
-
from phylogenie.generators.configs import (
|
|
7
|
-
SkylineMatrixModel,
|
|
8
|
-
SkylineParameterModel,
|
|
9
|
-
SkylineVectorModel,
|
|
10
|
-
)
|
|
11
6
|
from phylogenie.generators.dataset import DatasetGenerator
|
|
12
7
|
from phylogenie.generators.trees import (
|
|
13
8
|
BDEITreeDatasetGenerator,
|
|
@@ -34,7 +29,4 @@ __all__ = [
|
|
|
34
29
|
"BDTreeDatasetGenerator",
|
|
35
30
|
"BDEITreeDatasetGenerator",
|
|
36
31
|
"BDSSTreeDatasetGenerator",
|
|
37
|
-
"SkylineMatrixModel",
|
|
38
|
-
"SkylineParameterModel",
|
|
39
|
-
"SkylineVectorModel",
|
|
40
32
|
]
|
phylogenie/generators/dataset.py
CHANGED
|
@@ -11,6 +11,7 @@ from numpy.random import Generator, default_rng
|
|
|
11
11
|
from tqdm import tqdm
|
|
12
12
|
|
|
13
13
|
import phylogenie.generators.configs as cfg
|
|
14
|
+
from phylogenie.generators.factories import eval_expression
|
|
14
15
|
|
|
15
16
|
|
|
16
17
|
class DataType(str, Enum):
|
|
@@ -56,18 +57,23 @@ class DatasetGenerator(ABC, cfg.StrictBaseModel):
|
|
|
56
57
|
if self.context is not None:
|
|
57
58
|
for d, (k, v) in product(data, self.context.items()):
|
|
58
59
|
args = v.model_extra if v.model_extra is not None else {}
|
|
60
|
+
for arg_name, arg_value in args.items():
|
|
61
|
+
if isinstance(arg_value, str):
|
|
62
|
+
args[arg_name] = eval_expression(arg_value, d)
|
|
59
63
|
d[k] = np.array(getattr(rng, v.type)(**args)).tolist()
|
|
60
64
|
df = pd.DataFrame([{"file_id": str(i), **d} for i, d in enumerate(data)])
|
|
61
65
|
df.to_csv(os.path.join(output_dir, METADATA_FILENAME), index=False)
|
|
62
66
|
|
|
63
|
-
joblib.Parallel(n_jobs=self.n_jobs)(
|
|
67
|
+
jobs = joblib.Parallel(n_jobs=self.n_jobs, return_as="generator_unordered")(
|
|
64
68
|
joblib.delayed(self.generate_one)(
|
|
65
69
|
filename=os.path.join(data_dir, str(i)),
|
|
66
70
|
data=data[i],
|
|
67
71
|
seed=int(rng.integers(2**32)),
|
|
68
72
|
)
|
|
69
|
-
for i in
|
|
73
|
+
for i in range(n_samples)
|
|
70
74
|
)
|
|
75
|
+
for _ in tqdm(jobs, total=n_samples, desc=f"Generating {data_dir}..."):
|
|
76
|
+
pass
|
|
71
77
|
|
|
72
78
|
def generate(self) -> None:
|
|
73
79
|
rng = default_rng(self.seed)
|
|
@@ -16,7 +16,7 @@ from phylogenie.skyline import (
|
|
|
16
16
|
)
|
|
17
17
|
|
|
18
18
|
|
|
19
|
-
def
|
|
19
|
+
def eval_expression(expression: str, data: dict[str, Any]) -> Any:
|
|
20
20
|
return np.array(
|
|
21
21
|
eval(
|
|
22
22
|
expression,
|
|
@@ -31,7 +31,7 @@ def _eval_expression(expression: str, data: dict[str, Any]) -> Any:
|
|
|
31
31
|
|
|
32
32
|
def integer(x: cfg.Integer, data: dict[str, Any]) -> int:
|
|
33
33
|
if isinstance(x, str):
|
|
34
|
-
e =
|
|
34
|
+
e = eval_expression(x, data)
|
|
35
35
|
if isinstance(e, int):
|
|
36
36
|
return e
|
|
37
37
|
raise ValueError(
|
|
@@ -42,7 +42,7 @@ def integer(x: cfg.Integer, data: dict[str, Any]) -> int:
|
|
|
42
42
|
|
|
43
43
|
def scalar(x: cfg.Scalar, data: dict[str, Any]) -> pgt.Scalar:
|
|
44
44
|
if isinstance(x, str):
|
|
45
|
-
e =
|
|
45
|
+
e = eval_expression(x, data)
|
|
46
46
|
if isinstance(e, pgt.Scalar):
|
|
47
47
|
return e
|
|
48
48
|
raise ValueError(
|
|
@@ -53,7 +53,7 @@ def scalar(x: cfg.Scalar, data: dict[str, Any]) -> pgt.Scalar:
|
|
|
53
53
|
|
|
54
54
|
def many_scalars(x: cfg.ManyScalars, data: dict[str, Any]) -> pgt.ManyScalars:
|
|
55
55
|
if isinstance(x, str):
|
|
56
|
-
e =
|
|
56
|
+
e = eval_expression(x, data)
|
|
57
57
|
if tg.is_many_scalars(e):
|
|
58
58
|
return e
|
|
59
59
|
raise ValueError(
|
|
@@ -66,7 +66,7 @@ def one_or_many_scalars(
|
|
|
66
66
|
x: cfg.OneOrManyScalars, data: dict[str, Any]
|
|
67
67
|
) -> pgt.OneOrManyScalars:
|
|
68
68
|
if isinstance(x, str):
|
|
69
|
-
e =
|
|
69
|
+
e = eval_expression(x, data)
|
|
70
70
|
if tg.is_one_or_many_scalars(e):
|
|
71
71
|
return e
|
|
72
72
|
raise ValueError(
|
|
@@ -92,7 +92,7 @@ def skyline_vector(
|
|
|
92
92
|
x: cfg.SkylineVector, data: dict[str, Any]
|
|
93
93
|
) -> SkylineVectorCoercible:
|
|
94
94
|
if isinstance(x, str):
|
|
95
|
-
e =
|
|
95
|
+
e = eval_expression(x, data)
|
|
96
96
|
if tg.is_one_or_many_scalars(e):
|
|
97
97
|
return e
|
|
98
98
|
raise ValueError(
|
|
@@ -107,7 +107,7 @@ def skyline_vector(
|
|
|
107
107
|
|
|
108
108
|
change_times = many_scalars(x.change_times, data)
|
|
109
109
|
if isinstance(x.value, str):
|
|
110
|
-
e =
|
|
110
|
+
e = eval_expression(x.value, data)
|
|
111
111
|
if tg.is_many_one_or_many_scalars(e):
|
|
112
112
|
value = e
|
|
113
113
|
else:
|
|
@@ -135,7 +135,7 @@ def one_or_many_2D_scalars(
|
|
|
135
135
|
x: cfg.OneOrMany2DScalars, data: dict[str, Any]
|
|
136
136
|
) -> pgt.OneOrMany2DScalars:
|
|
137
137
|
if isinstance(x, str):
|
|
138
|
-
e =
|
|
138
|
+
e = eval_expression(x, data)
|
|
139
139
|
if tg.is_one_or_many_2D_scalars(e):
|
|
140
140
|
return e
|
|
141
141
|
raise ValueError(
|
|
@@ -153,7 +153,7 @@ def skyline_matrix(
|
|
|
153
153
|
return None
|
|
154
154
|
|
|
155
155
|
if isinstance(x, str):
|
|
156
|
-
e =
|
|
156
|
+
e = eval_expression(x, data)
|
|
157
157
|
if tg.is_one_or_many_2D_scalars(e):
|
|
158
158
|
return e
|
|
159
159
|
raise ValueError(
|
|
@@ -168,7 +168,7 @@ def skyline_matrix(
|
|
|
168
168
|
|
|
169
169
|
change_times = many_scalars(x.change_times, data)
|
|
170
170
|
if isinstance(x.value, str):
|
|
171
|
-
e =
|
|
171
|
+
e = eval_expression(x.value, data)
|
|
172
172
|
if tg.is_many_one_or_many_2D_scalars(e):
|
|
173
173
|
value = e
|
|
174
174
|
else:
|
phylogenie/generators/trees.py
CHANGED
|
@@ -23,6 +23,7 @@ from phylogenie.treesimulator import (
|
|
|
23
23
|
get_BDEI_events,
|
|
24
24
|
get_BDSS_events,
|
|
25
25
|
get_canonical_events,
|
|
26
|
+
get_contact_tracing_events,
|
|
26
27
|
get_epidemiological_events,
|
|
27
28
|
get_FBD_events,
|
|
28
29
|
simulate_tree,
|
|
@@ -41,16 +42,14 @@ class ParameterizationType(str, Enum):
|
|
|
41
42
|
class TreeDatasetGenerator(DatasetGenerator):
|
|
42
43
|
data_type: Literal[DataType.TREES] = DataType.TREES
|
|
43
44
|
min_tips: cfg.Integer = 1
|
|
44
|
-
max_tips: cfg.Integer
|
|
45
|
+
max_tips: cfg.Integer = 2**32
|
|
45
46
|
max_time: cfg.Scalar = np.inf
|
|
46
47
|
init_state: str | None = None
|
|
47
48
|
sampling_probability_at_present: cfg.Scalar = 0.0
|
|
48
49
|
max_tries: int | None = None
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
sampling_rate_after_notification: cfg.SkylineParameter = np.inf
|
|
53
|
-
contacts_removal_probability: cfg.SkylineParameter = 1
|
|
50
|
+
|
|
51
|
+
@abstractmethod
|
|
52
|
+
def _get_events(self, data: dict[str, Any]) -> list[Event]: ...
|
|
54
53
|
|
|
55
54
|
def simulate_one(self, rng: Generator, data: dict[str, Any]) -> Tree | None:
|
|
56
55
|
events = self._get_events(data)
|
|
@@ -59,34 +58,19 @@ class TreeDatasetGenerator(DatasetGenerator):
|
|
|
59
58
|
if self.init_state is None
|
|
60
59
|
else self.init_state.format(**data)
|
|
61
60
|
)
|
|
62
|
-
max_tips = (
|
|
63
|
-
self.max_tips if self.max_tips is None else integer(self.max_tips, data)
|
|
64
|
-
)
|
|
65
61
|
return simulate_tree(
|
|
66
62
|
events=events,
|
|
67
63
|
min_tips=integer(self.min_tips, data),
|
|
68
|
-
max_tips=max_tips,
|
|
64
|
+
max_tips=integer(self.max_tips, data),
|
|
69
65
|
max_time=scalar(self.max_time, data),
|
|
70
66
|
init_state=init_state,
|
|
71
67
|
sampling_probability_at_present=scalar(
|
|
72
68
|
self.sampling_probability_at_present, data
|
|
73
69
|
),
|
|
74
|
-
notification_probability=scalar(self.notification_probability, data),
|
|
75
|
-
max_notified_contacts=integer(self.max_notified_contacts, data),
|
|
76
|
-
samplable_states_after_notification=self.samplable_states_after_notification,
|
|
77
|
-
sampling_rate_after_notification=skyline_parameter(
|
|
78
|
-
self.sampling_rate_after_notification, data
|
|
79
|
-
),
|
|
80
|
-
contacts_removal_probability=skyline_parameter(
|
|
81
|
-
self.contacts_removal_probability, data
|
|
82
|
-
),
|
|
83
70
|
max_tries=self.max_tries,
|
|
84
71
|
seed=int(rng.integers(2**32)),
|
|
85
72
|
)
|
|
86
73
|
|
|
87
|
-
@abstractmethod
|
|
88
|
-
def _get_events(self, data: dict[str, Any]) -> list[Event]: ...
|
|
89
|
-
|
|
90
74
|
def _generate_one(
|
|
91
75
|
self, filename: str, rng: Generator, data: dict[str, Any]
|
|
92
76
|
) -> None:
|
|
@@ -101,9 +85,9 @@ class CanonicalTreeDatasetGenerator(TreeDatasetGenerator):
|
|
|
101
85
|
)
|
|
102
86
|
states: list[str]
|
|
103
87
|
sampling_rates: cfg.SkylineVector
|
|
88
|
+
remove_after_sampling: bool
|
|
104
89
|
birth_rates: cfg.SkylineVector = 0
|
|
105
90
|
death_rates: cfg.SkylineVector = 0
|
|
106
|
-
removal_probabilities: cfg.SkylineVector = 0
|
|
107
91
|
migration_rates: cfg.SkylineMatrix = None
|
|
108
92
|
birth_rates_among_states: cfg.SkylineMatrix = None
|
|
109
93
|
|
|
@@ -111,9 +95,9 @@ class CanonicalTreeDatasetGenerator(TreeDatasetGenerator):
|
|
|
111
95
|
return get_canonical_events(
|
|
112
96
|
states=self.states,
|
|
113
97
|
sampling_rates=skyline_vector(self.sampling_rates, data),
|
|
98
|
+
remove_after_sampling=self.remove_after_sampling,
|
|
114
99
|
birth_rates=skyline_vector(self.birth_rates, data),
|
|
115
100
|
death_rates=skyline_vector(self.death_rates, data),
|
|
116
|
-
removal_probabilities=skyline_vector(self.removal_probabilities, data),
|
|
117
101
|
migration_rates=skyline_matrix(self.migration_rates, data),
|
|
118
102
|
birth_rates_among_states=skyline_matrix(
|
|
119
103
|
self.birth_rates_among_states, data
|
|
@@ -121,19 +105,66 @@ class CanonicalTreeDatasetGenerator(TreeDatasetGenerator):
|
|
|
121
105
|
)
|
|
122
106
|
|
|
123
107
|
|
|
124
|
-
class
|
|
108
|
+
class FBDTreeDatasetGenerator(TreeDatasetGenerator):
|
|
109
|
+
parameterization: Literal[ParameterizationType.FBD] = ParameterizationType.FBD
|
|
110
|
+
states: list[str]
|
|
111
|
+
sampling_proportions: cfg.SkylineVector
|
|
112
|
+
diversification: cfg.SkylineVector = 0
|
|
113
|
+
turnover: cfg.SkylineVector = 0
|
|
114
|
+
migration_rates: cfg.SkylineMatrix = None
|
|
115
|
+
diversification_between_states: cfg.SkylineMatrix = None
|
|
116
|
+
|
|
117
|
+
def _get_events(self, data: dict[str, Any]) -> list[Event]:
|
|
118
|
+
return get_FBD_events(
|
|
119
|
+
states=self.states,
|
|
120
|
+
diversification=skyline_vector(self.diversification, data),
|
|
121
|
+
turnover=skyline_vector(self.turnover, data),
|
|
122
|
+
sampling_proportions=skyline_vector(self.sampling_proportions, data),
|
|
123
|
+
migration_rates=skyline_matrix(self.migration_rates, data),
|
|
124
|
+
diversification_between_states=skyline_matrix(
|
|
125
|
+
self.diversification_between_states, data
|
|
126
|
+
),
|
|
127
|
+
)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class TreeDatasetGeneratorWithContactTracing(TreeDatasetGenerator):
|
|
131
|
+
max_notified_contacts: cfg.Integer = 1
|
|
132
|
+
notification_probability: cfg.SkylineParameter = 0.0
|
|
133
|
+
sampling_rate_after_notification: cfg.SkylineParameter = np.inf
|
|
134
|
+
samplable_states_after_notification: list[str] | None = None
|
|
135
|
+
|
|
136
|
+
@abstractmethod
|
|
137
|
+
def _get_base_events(self, data: dict[str, Any]) -> list[Event]: ...
|
|
138
|
+
|
|
139
|
+
def _get_events(self, data: dict[str, Any]) -> list[Event]:
|
|
140
|
+
events = self._get_base_events(data)
|
|
141
|
+
if self.notification_probability:
|
|
142
|
+
events = get_contact_tracing_events(
|
|
143
|
+
events=events,
|
|
144
|
+
max_notified_contacts=integer(self.max_notified_contacts, data),
|
|
145
|
+
notification_probability=skyline_parameter(
|
|
146
|
+
self.notification_probability, data
|
|
147
|
+
),
|
|
148
|
+
sampling_rate_after_notification=skyline_parameter(
|
|
149
|
+
self.sampling_rate_after_notification, data
|
|
150
|
+
),
|
|
151
|
+
samplable_states_after_notification=self.samplable_states_after_notification,
|
|
152
|
+
)
|
|
153
|
+
return events
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
class EpidemiologicalTreeDatasetGenerator(TreeDatasetGeneratorWithContactTracing):
|
|
125
157
|
parameterization: Literal[ParameterizationType.EPIDEMIOLOGICAL] = (
|
|
126
158
|
ParameterizationType.EPIDEMIOLOGICAL
|
|
127
159
|
)
|
|
128
160
|
states: list[str]
|
|
161
|
+
sampling_proportions: cfg.SkylineVector
|
|
129
162
|
reproduction_numbers: cfg.SkylineVector = 0
|
|
130
163
|
become_uninfectious_rates: cfg.SkylineVector = 0
|
|
131
|
-
sampling_proportions: cfg.SkylineVector = 1
|
|
132
|
-
removal_probabilities: cfg.SkylineVector = 1
|
|
133
164
|
migration_rates: cfg.SkylineMatrix = None
|
|
134
165
|
reproduction_numbers_among_states: cfg.SkylineMatrix = None
|
|
135
166
|
|
|
136
|
-
def
|
|
167
|
+
def _get_base_events(self, data: dict[str, Any]) -> list[Event]:
|
|
137
168
|
return get_epidemiological_events(
|
|
138
169
|
states=self.states,
|
|
139
170
|
reproduction_numbers=skyline_vector(self.reproduction_numbers, data),
|
|
@@ -141,7 +172,6 @@ class EpidemiologicalTreeDatasetGenerator(TreeDatasetGenerator):
|
|
|
141
172
|
self.become_uninfectious_rates, data
|
|
142
173
|
),
|
|
143
174
|
sampling_proportions=skyline_vector(self.sampling_proportions, data),
|
|
144
|
-
removal_probabilities=skyline_vector(self.removal_probabilities, data),
|
|
145
175
|
migration_rates=skyline_matrix(self.migration_rates, data),
|
|
146
176
|
reproduction_numbers_among_states=skyline_matrix(
|
|
147
177
|
self.reproduction_numbers_among_states, data
|
|
@@ -149,37 +179,13 @@ class EpidemiologicalTreeDatasetGenerator(TreeDatasetGenerator):
|
|
|
149
179
|
)
|
|
150
180
|
|
|
151
181
|
|
|
152
|
-
class
|
|
153
|
-
parameterization: Literal[ParameterizationType.FBD] = ParameterizationType.FBD
|
|
154
|
-
states: list[str]
|
|
155
|
-
diversification: cfg.SkylineVector = 0
|
|
156
|
-
turnover: cfg.SkylineVector = 0
|
|
157
|
-
sampling_proportions: cfg.SkylineVector = 1
|
|
158
|
-
removal_probabilities: cfg.SkylineVector = 0
|
|
159
|
-
migration_rates: cfg.SkylineMatrix = None
|
|
160
|
-
diversification_between_types: cfg.SkylineMatrix = None
|
|
161
|
-
|
|
162
|
-
def _get_events(self, data: dict[str, Any]) -> list[Event]:
|
|
163
|
-
return get_FBD_events(
|
|
164
|
-
states=self.states,
|
|
165
|
-
diversification=skyline_vector(self.diversification, data),
|
|
166
|
-
turnover=skyline_vector(self.turnover, data),
|
|
167
|
-
sampling_proportions=skyline_vector(self.sampling_proportions, data),
|
|
168
|
-
removal_probabilities=skyline_vector(self.removal_probabilities, data),
|
|
169
|
-
migration_rates=skyline_matrix(self.migration_rates, data),
|
|
170
|
-
diversification_between_types=skyline_matrix(
|
|
171
|
-
self.diversification_between_types, data
|
|
172
|
-
),
|
|
173
|
-
)
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
class BDTreeDatasetGenerator(TreeDatasetGenerator):
|
|
182
|
+
class BDTreeDatasetGenerator(TreeDatasetGeneratorWithContactTracing):
|
|
177
183
|
parameterization: Literal[ParameterizationType.BD] = ParameterizationType.BD
|
|
178
184
|
reproduction_number: cfg.SkylineParameter
|
|
179
185
|
infectious_period: cfg.SkylineParameter
|
|
180
186
|
sampling_proportion: cfg.SkylineParameter = 1
|
|
181
187
|
|
|
182
|
-
def
|
|
188
|
+
def _get_base_events(self, data: dict[str, Any]) -> list[Event]:
|
|
183
189
|
return get_BD_events(
|
|
184
190
|
reproduction_number=skyline_parameter(self.reproduction_number, data),
|
|
185
191
|
infectious_period=skyline_parameter(self.infectious_period, data),
|
|
@@ -187,14 +193,14 @@ class BDTreeDatasetGenerator(TreeDatasetGenerator):
|
|
|
187
193
|
)
|
|
188
194
|
|
|
189
195
|
|
|
190
|
-
class BDEITreeDatasetGenerator(
|
|
196
|
+
class BDEITreeDatasetGenerator(TreeDatasetGeneratorWithContactTracing):
|
|
191
197
|
parameterization: Literal[ParameterizationType.BDEI] = ParameterizationType.BDEI
|
|
192
198
|
reproduction_number: cfg.SkylineParameter
|
|
193
199
|
infectious_period: cfg.SkylineParameter
|
|
194
200
|
incubation_period: cfg.SkylineParameter
|
|
195
201
|
sampling_proportion: cfg.SkylineParameter = 1
|
|
196
202
|
|
|
197
|
-
def
|
|
203
|
+
def _get_base_events(self, data: dict[str, Any]) -> list[Event]:
|
|
198
204
|
return get_BDEI_events(
|
|
199
205
|
reproduction_number=skyline_parameter(self.reproduction_number, data),
|
|
200
206
|
infectious_period=skyline_parameter(self.infectious_period, data),
|
|
@@ -203,7 +209,7 @@ class BDEITreeDatasetGenerator(TreeDatasetGenerator):
|
|
|
203
209
|
)
|
|
204
210
|
|
|
205
211
|
|
|
206
|
-
class BDSSTreeDatasetGenerator(
|
|
212
|
+
class BDSSTreeDatasetGenerator(TreeDatasetGeneratorWithContactTracing):
|
|
207
213
|
parameterization: Literal[ParameterizationType.BDSS] = ParameterizationType.BDSS
|
|
208
214
|
reproduction_number: cfg.SkylineParameter
|
|
209
215
|
infectious_period: cfg.SkylineParameter
|
|
@@ -211,7 +217,7 @@ class BDSSTreeDatasetGenerator(TreeDatasetGenerator):
|
|
|
211
217
|
superspreaders_proportion: cfg.SkylineParameter
|
|
212
218
|
sampling_proportion: cfg.SkylineParameter = 1
|
|
213
219
|
|
|
214
|
-
def
|
|
220
|
+
def _get_base_events(self, data: dict[str, Any]) -> list[Event]:
|
|
215
221
|
return get_BDSS_events(
|
|
216
222
|
reproduction_number=skyline_parameter(self.reproduction_number, data),
|
|
217
223
|
infectious_period=skyline_parameter(self.infectious_period, data),
|
phylogenie/io.py
CHANGED
|
@@ -66,9 +66,12 @@ def _to_newick(tree: Tree) -> str:
|
|
|
66
66
|
return newick
|
|
67
67
|
|
|
68
68
|
|
|
69
|
-
def dump_newick(
|
|
69
|
+
def dump_newick(trees: Tree | list[Tree], filepath: str) -> None:
|
|
70
|
+
if isinstance(trees, Tree):
|
|
71
|
+
trees = [trees]
|
|
70
72
|
with open(filepath, "w") as file:
|
|
71
|
-
|
|
73
|
+
for t in trees:
|
|
74
|
+
file.write(_to_newick(t) + ";\n")
|
|
72
75
|
|
|
73
76
|
|
|
74
77
|
def load_fasta(
|
|
@@ -8,7 +8,7 @@ from phylogenie.treesimulator.events import (
|
|
|
8
8
|
get_epidemiological_events,
|
|
9
9
|
get_FBD_events,
|
|
10
10
|
)
|
|
11
|
-
from phylogenie.treesimulator.gillespie import simulate_tree
|
|
11
|
+
from phylogenie.treesimulator.gillespie import generate_trees, simulate_tree
|
|
12
12
|
|
|
13
13
|
__all__ = [
|
|
14
14
|
"Event",
|
|
@@ -19,5 +19,6 @@ __all__ = [
|
|
|
19
19
|
"get_contact_tracing_events",
|
|
20
20
|
"get_epidemiological_events",
|
|
21
21
|
"get_FBD_events",
|
|
22
|
+
"generate_trees",
|
|
22
23
|
"simulate_tree",
|
|
23
24
|
]
|
|
@@ -3,6 +3,7 @@ from collections.abc import Sequence
|
|
|
3
3
|
from dataclasses import dataclass
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
|
+
from numpy.random import Generator
|
|
6
7
|
|
|
7
8
|
from phylogenie.skyline import (
|
|
8
9
|
SkylineMatrixCoercible,
|
|
@@ -13,11 +14,20 @@ from phylogenie.skyline import (
|
|
|
13
14
|
skyline_parameter,
|
|
14
15
|
skyline_vector,
|
|
15
16
|
)
|
|
16
|
-
from phylogenie.treesimulator.model import Model
|
|
17
|
+
from phylogenie.treesimulator.model import Model
|
|
17
18
|
|
|
18
19
|
INFECTIOUS_STATE = "I"
|
|
19
20
|
EXPOSED_STATE = "E"
|
|
20
21
|
SUPERSPREADER_STATE = "S"
|
|
22
|
+
CT_POSTFIX = "-CT"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _get_CT_state(state: str) -> str:
|
|
26
|
+
return f"{state}{CT_POSTFIX}"
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _is_CT_state(state: str) -> bool:
|
|
30
|
+
return state.endswith(CT_POSTFIX)
|
|
21
31
|
|
|
22
32
|
|
|
23
33
|
@dataclass
|
|
@@ -25,6 +35,9 @@ class Event(ABC):
|
|
|
25
35
|
rate: SkylineParameter
|
|
26
36
|
state: str
|
|
27
37
|
|
|
38
|
+
def draw_individual(self, model: Model, rng: Generator) -> int:
|
|
39
|
+
return rng.choice(model.get_population(self.state))
|
|
40
|
+
|
|
28
41
|
def get_propensity(self, model: Model, time: float) -> float:
|
|
29
42
|
n_individuals = model.count_individuals(self.state)
|
|
30
43
|
rate = self.rate.get_value_at_time(time)
|
|
@@ -33,48 +46,65 @@ class Event(ABC):
|
|
|
33
46
|
return rate * n_individuals
|
|
34
47
|
|
|
35
48
|
@abstractmethod
|
|
36
|
-
def apply(self, model: Model, time: float) -> None: ...
|
|
49
|
+
def apply(self, model: Model, time: float, rng: Generator) -> None: ...
|
|
37
50
|
|
|
38
51
|
|
|
39
52
|
@dataclass
|
|
40
|
-
class
|
|
53
|
+
class Birth(Event):
|
|
41
54
|
child_state: str
|
|
42
55
|
|
|
43
|
-
def apply(self, model: Model, time: float) -> None:
|
|
44
|
-
individual =
|
|
56
|
+
def apply(self, model: Model, time: float, rng: Generator) -> None:
|
|
57
|
+
individual = self.draw_individual(model, rng)
|
|
45
58
|
model.birth_from(individual, self.child_state, time)
|
|
46
59
|
|
|
47
60
|
|
|
48
|
-
class
|
|
49
|
-
def apply(self, model: Model, time: float) -> None:
|
|
50
|
-
individual =
|
|
61
|
+
class Death(Event):
|
|
62
|
+
def apply(self, model: Model, time: float, rng: Generator) -> None:
|
|
63
|
+
individual = self.draw_individual(model, rng)
|
|
51
64
|
model.remove(individual, time)
|
|
52
65
|
|
|
53
66
|
|
|
54
67
|
@dataclass
|
|
55
|
-
class
|
|
68
|
+
class Migration(Event):
|
|
56
69
|
target_state: str
|
|
57
70
|
|
|
58
|
-
def apply(self, model: Model, time: float) -> None:
|
|
59
|
-
individual =
|
|
71
|
+
def apply(self, model: Model, time: float, rng: Generator) -> None:
|
|
72
|
+
individual = self.draw_individual(model, rng)
|
|
60
73
|
model.migrate(individual, self.target_state, time)
|
|
61
74
|
|
|
62
75
|
|
|
63
76
|
@dataclass
|
|
64
|
-
class
|
|
65
|
-
|
|
77
|
+
class Sampling(Event):
|
|
78
|
+
removal: bool
|
|
79
|
+
|
|
80
|
+
def apply(self, model: Model, time: float, rng: Generator) -> None:
|
|
81
|
+
individual = self.draw_individual(model, rng)
|
|
82
|
+
model.sample(individual, time, self.removal)
|
|
66
83
|
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
84
|
+
|
|
85
|
+
@dataclass
|
|
86
|
+
class SamplingWithContactTracing(Event):
|
|
87
|
+
max_notified_contacts: int
|
|
88
|
+
notification_probability: SkylineParameter
|
|
89
|
+
|
|
90
|
+
def apply(self, model: Model, time: float, rng: Generator) -> None:
|
|
91
|
+
individual = self.draw_individual(model, rng)
|
|
92
|
+
model.sample(individual, time, True)
|
|
93
|
+
population = model.get_population()
|
|
94
|
+
for contact in model.get_lineage(individual)[-self.max_notified_contacts :]:
|
|
95
|
+
if contact in population:
|
|
96
|
+
state = model.get_state(contact)
|
|
97
|
+
p = self.notification_probability.get_value_at_time(time)
|
|
98
|
+
if not _is_CT_state(state) and rng.random() < p:
|
|
99
|
+
model.migrate(contact, _get_CT_state(state), time)
|
|
70
100
|
|
|
71
101
|
|
|
72
102
|
def get_canonical_events(
|
|
73
103
|
states: Sequence[str],
|
|
74
104
|
sampling_rates: SkylineVectorCoercible,
|
|
105
|
+
remove_after_sampling: bool,
|
|
75
106
|
birth_rates: SkylineVectorCoercible = 0,
|
|
76
107
|
death_rates: SkylineVectorCoercible = 0,
|
|
77
|
-
removal_probabilities: SkylineVectorCoercible = 0,
|
|
78
108
|
migration_rates: SkylineMatrixCoercible | None = None,
|
|
79
109
|
birth_rates_among_states: SkylineMatrixCoercible | None = None,
|
|
80
110
|
) -> list[Event]:
|
|
@@ -83,102 +113,94 @@ def get_canonical_events(
|
|
|
83
113
|
birth_rates = skyline_vector(birth_rates, N)
|
|
84
114
|
death_rates = skyline_vector(death_rates, N)
|
|
85
115
|
sampling_rates = skyline_vector(sampling_rates, N)
|
|
86
|
-
removal_probabilities = skyline_vector(removal_probabilities, N)
|
|
87
116
|
|
|
88
117
|
events: list[Event] = []
|
|
89
118
|
for i, state in enumerate(states):
|
|
90
|
-
events.append(
|
|
91
|
-
events.append(
|
|
92
|
-
events.append(
|
|
119
|
+
events.append(Birth(birth_rates[i], state, state))
|
|
120
|
+
events.append(Death(death_rates[i], state))
|
|
121
|
+
events.append(Sampling(sampling_rates[i], state, remove_after_sampling))
|
|
93
122
|
|
|
94
123
|
if migration_rates is not None:
|
|
95
124
|
migration_rates = skyline_matrix(migration_rates, N, N - 1)
|
|
96
125
|
for i, state in enumerate(states):
|
|
97
126
|
for j, other_state in enumerate([s for s in states if s != state]):
|
|
98
|
-
events.append(
|
|
127
|
+
events.append(Migration(migration_rates[i, j], state, other_state))
|
|
99
128
|
|
|
100
129
|
if birth_rates_among_states is not None:
|
|
101
130
|
birth_rates_among_states = skyline_matrix(birth_rates_among_states, N, N - 1)
|
|
102
131
|
for i, state in enumerate(states):
|
|
103
132
|
for j, other_state in enumerate([s for s in states if s != state]):
|
|
104
|
-
events.append(
|
|
105
|
-
BirthEvent(birth_rates_among_states[i, j], state, other_state)
|
|
106
|
-
)
|
|
133
|
+
events.append(Birth(birth_rates_among_states[i, j], state, other_state))
|
|
107
134
|
|
|
108
135
|
return [event for event in events if event.rate]
|
|
109
136
|
|
|
110
137
|
|
|
111
|
-
def
|
|
138
|
+
def get_FBD_events(
|
|
112
139
|
states: Sequence[str],
|
|
113
140
|
sampling_proportions: SkylineVectorCoercible = 1,
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
removal_probabilities: SkylineVectorCoercible = 1,
|
|
141
|
+
diversification: SkylineVectorCoercible = 0,
|
|
142
|
+
turnover: SkylineVectorCoercible = 0,
|
|
117
143
|
migration_rates: SkylineMatrixCoercible | None = None,
|
|
118
|
-
|
|
144
|
+
diversification_between_states: SkylineMatrixCoercible | None = None,
|
|
119
145
|
) -> list[Event]:
|
|
120
146
|
N = len(states)
|
|
121
147
|
|
|
122
|
-
|
|
123
|
-
|
|
148
|
+
diversification = skyline_vector(diversification, N)
|
|
149
|
+
turnover = skyline_vector(turnover, N)
|
|
124
150
|
sampling_proportions = skyline_vector(sampling_proportions, N)
|
|
125
|
-
removal_probabilities = skyline_vector(removal_probabilities, N)
|
|
126
151
|
|
|
127
|
-
birth_rates =
|
|
128
|
-
|
|
129
|
-
|
|
152
|
+
birth_rates = diversification / (1 - turnover)
|
|
153
|
+
death_rates = turnover * birth_rates
|
|
154
|
+
sampling_rates = sampling_proportions * death_rates
|
|
130
155
|
birth_rates_among_states = (
|
|
131
|
-
(
|
|
132
|
-
|
|
133
|
-
* become_uninfectious_rates
|
|
134
|
-
)
|
|
135
|
-
if reproduction_numbers_among_states is not None
|
|
156
|
+
(skyline_matrix(diversification_between_states, N, N - 1) + death_rates)
|
|
157
|
+
if diversification_between_states is not None
|
|
136
158
|
else None
|
|
137
159
|
)
|
|
138
160
|
|
|
139
161
|
return get_canonical_events(
|
|
140
162
|
states=states,
|
|
163
|
+
sampling_rates=sampling_rates,
|
|
164
|
+
remove_after_sampling=False,
|
|
141
165
|
birth_rates=birth_rates,
|
|
142
166
|
death_rates=death_rates,
|
|
143
|
-
sampling_rates=sampling_rates,
|
|
144
|
-
removal_probabilities=removal_probabilities,
|
|
145
167
|
migration_rates=migration_rates,
|
|
146
168
|
birth_rates_among_states=birth_rates_among_states,
|
|
147
169
|
)
|
|
148
170
|
|
|
149
171
|
|
|
150
|
-
def
|
|
172
|
+
def get_epidemiological_events(
|
|
151
173
|
states: Sequence[str],
|
|
152
|
-
diversification: SkylineVectorCoercible = 0,
|
|
153
|
-
turnover: SkylineVectorCoercible = 0,
|
|
154
174
|
sampling_proportions: SkylineVectorCoercible = 1,
|
|
155
|
-
|
|
175
|
+
reproduction_numbers: SkylineVectorCoercible = 0,
|
|
176
|
+
become_uninfectious_rates: SkylineVectorCoercible = 0,
|
|
156
177
|
migration_rates: SkylineMatrixCoercible | None = None,
|
|
157
|
-
|
|
178
|
+
reproduction_numbers_among_states: SkylineMatrixCoercible | None = None,
|
|
158
179
|
) -> list[Event]:
|
|
159
180
|
N = len(states)
|
|
160
181
|
|
|
161
|
-
|
|
162
|
-
|
|
182
|
+
reproduction_numbers = skyline_vector(reproduction_numbers, N)
|
|
183
|
+
become_uninfectious_rates = skyline_vector(become_uninfectious_rates, N)
|
|
163
184
|
sampling_proportions = skyline_vector(sampling_proportions, N)
|
|
164
|
-
removal_probabilities = skyline_vector(removal_probabilities, N)
|
|
165
185
|
|
|
166
|
-
birth_rates =
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
sampling_rates = sampling_proportions * death_rates / sampling_rates_dividend
|
|
186
|
+
birth_rates = reproduction_numbers * become_uninfectious_rates
|
|
187
|
+
sampling_rates = become_uninfectious_rates * sampling_proportions
|
|
188
|
+
death_rates = become_uninfectious_rates - sampling_rates
|
|
170
189
|
birth_rates_among_states = (
|
|
171
|
-
(
|
|
172
|
-
|
|
190
|
+
(
|
|
191
|
+
skyline_matrix(reproduction_numbers_among_states, N, N - 1)
|
|
192
|
+
* become_uninfectious_rates
|
|
193
|
+
)
|
|
194
|
+
if reproduction_numbers_among_states is not None
|
|
173
195
|
else None
|
|
174
196
|
)
|
|
175
197
|
|
|
176
198
|
return get_canonical_events(
|
|
177
199
|
states=states,
|
|
200
|
+
sampling_rates=sampling_rates,
|
|
201
|
+
remove_after_sampling=True,
|
|
178
202
|
birth_rates=birth_rates,
|
|
179
203
|
death_rates=death_rates,
|
|
180
|
-
sampling_rates=sampling_rates,
|
|
181
|
-
removal_probabilities=removal_probabilities,
|
|
182
204
|
migration_rates=migration_rates,
|
|
183
205
|
birth_rates_among_states=birth_rates_among_states,
|
|
184
206
|
)
|
|
@@ -236,35 +258,51 @@ def get_BDSS_events(
|
|
|
236
258
|
|
|
237
259
|
def get_contact_tracing_events(
|
|
238
260
|
events: Sequence[Event],
|
|
239
|
-
|
|
261
|
+
max_notified_contacts: int = 1,
|
|
262
|
+
notification_probability: SkylineParameterLike = 1,
|
|
240
263
|
sampling_rate_after_notification: SkylineParameterLike = np.inf,
|
|
241
|
-
|
|
264
|
+
samplable_states_after_notification: Sequence[str] | None = None,
|
|
242
265
|
) -> list[Event]:
|
|
243
|
-
ct_events =
|
|
266
|
+
ct_events = [e for e in events if not isinstance(e, Sampling)]
|
|
267
|
+
|
|
244
268
|
for event in events:
|
|
245
|
-
if isinstance(event,
|
|
269
|
+
if isinstance(event, Migration):
|
|
246
270
|
ct_events.append(
|
|
247
|
-
|
|
271
|
+
Migration(
|
|
248
272
|
event.rate,
|
|
249
|
-
|
|
250
|
-
|
|
273
|
+
_get_CT_state(event.state),
|
|
274
|
+
_get_CT_state(event.target_state),
|
|
251
275
|
)
|
|
252
276
|
)
|
|
253
|
-
elif isinstance(event,
|
|
277
|
+
elif isinstance(event, Birth):
|
|
254
278
|
ct_events.append(
|
|
255
|
-
|
|
279
|
+
Birth(event.rate, _get_CT_state(event.state), event.child_state)
|
|
280
|
+
)
|
|
281
|
+
elif isinstance(event, Sampling):
|
|
282
|
+
if not event.removal:
|
|
283
|
+
raise ValueError(
|
|
284
|
+
"Contact tracing requires removal to be set for all sampling events."
|
|
285
|
+
)
|
|
286
|
+
ct_events.append(
|
|
287
|
+
SamplingWithContactTracing(
|
|
288
|
+
event.rate,
|
|
289
|
+
event.state,
|
|
290
|
+
max_notified_contacts,
|
|
291
|
+
skyline_parameter(notification_probability),
|
|
292
|
+
)
|
|
256
293
|
)
|
|
257
294
|
|
|
258
295
|
for state in (
|
|
259
296
|
samplable_states_after_notification
|
|
260
297
|
if samplable_states_after_notification is not None
|
|
261
|
-
else
|
|
298
|
+
else {e.state for e in events}
|
|
262
299
|
):
|
|
263
300
|
ct_events.append(
|
|
264
|
-
|
|
301
|
+
SamplingWithContactTracing(
|
|
265
302
|
skyline_parameter(sampling_rate_after_notification),
|
|
266
|
-
|
|
267
|
-
|
|
303
|
+
_get_CT_state(state),
|
|
304
|
+
max_notified_contacts,
|
|
305
|
+
skyline_parameter(notification_probability),
|
|
268
306
|
)
|
|
269
307
|
)
|
|
270
308
|
|
|
@@ -1,84 +1,128 @@
|
|
|
1
|
+
import os
|
|
1
2
|
from collections.abc import Sequence
|
|
2
3
|
|
|
4
|
+
import joblib
|
|
3
5
|
import numpy as np
|
|
4
6
|
from numpy.random import default_rng
|
|
7
|
+
from tqdm import tqdm
|
|
5
8
|
|
|
6
|
-
from phylogenie.
|
|
9
|
+
from phylogenie.io import dump_newick
|
|
7
10
|
from phylogenie.tree import Tree
|
|
8
|
-
from phylogenie.treesimulator.events import Event
|
|
11
|
+
from phylogenie.treesimulator.events import Event
|
|
9
12
|
from phylogenie.treesimulator.model import Model
|
|
10
13
|
|
|
11
14
|
|
|
12
15
|
def simulate_tree(
|
|
13
16
|
events: Sequence[Event],
|
|
14
17
|
min_tips: int = 1,
|
|
15
|
-
max_tips: int
|
|
18
|
+
max_tips: int = 2**32,
|
|
16
19
|
max_time: float = np.inf,
|
|
17
20
|
init_state: str | None = None,
|
|
18
21
|
sampling_probability_at_present: float = 0.0,
|
|
19
|
-
notification_probability: float = 0,
|
|
20
|
-
max_notified_contacts: int = 1,
|
|
21
|
-
samplable_states_after_notification: Sequence[str] | None = None,
|
|
22
|
-
sampling_rate_after_notification: SkylineParameterLike = np.inf,
|
|
23
|
-
contacts_removal_probability: SkylineParameterLike = 1,
|
|
24
22
|
max_tries: int | None = None,
|
|
25
23
|
seed: int | None = None,
|
|
26
24
|
) -> Tree | None:
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
if max_tips is None and max_time == np.inf:
|
|
30
|
-
raise ValueError("Either max_tips or max_time must be specified.")
|
|
25
|
+
if max_time == np.inf and max_tips == 2**32:
|
|
26
|
+
raise ValueError("Either max_time or max_tips must be specified.")
|
|
31
27
|
|
|
32
|
-
|
|
28
|
+
if max_time == np.inf and sampling_probability_at_present:
|
|
29
|
+
raise ValueError(
|
|
30
|
+
"sampling_probability_at_present cannot be set when max_time is infinite."
|
|
31
|
+
)
|
|
33
32
|
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
sampling_rate_after_notification,
|
|
39
|
-
contacts_removal_probability,
|
|
33
|
+
states = {e.state for e in events}
|
|
34
|
+
if init_state is None and len(states) > 1:
|
|
35
|
+
raise ValueError(
|
|
36
|
+
"Init state must be provided for models with more than one state."
|
|
40
37
|
)
|
|
38
|
+
elif init_state is None:
|
|
39
|
+
(init_state,) = states
|
|
40
|
+
elif init_state not in states:
|
|
41
|
+
raise ValueError(f"Init state {init_state} not found in event states: {states}")
|
|
41
42
|
|
|
43
|
+
rng = default_rng(seed)
|
|
42
44
|
n_tries = 0
|
|
43
45
|
while max_tries is None or n_tries < max_tries:
|
|
44
|
-
|
|
45
|
-
model = Model(root_state, max_notified_contacts, notification_probability, rng)
|
|
46
|
+
model = Model(init_state)
|
|
46
47
|
current_time = 0.0
|
|
47
48
|
change_times = sorted(set(t for e in events for t in e.rate.change_times))
|
|
48
49
|
next_change_time = change_times.pop(0) if change_times else np.inf
|
|
49
|
-
n_tips = None if max_tips is None else rng.integers(min_tips, max_tips + 1)
|
|
50
50
|
|
|
51
|
-
|
|
51
|
+
target_n_tips = rng.integers(min_tips, max_tips) if max_time == np.inf else None
|
|
52
|
+
while current_time < max_time:
|
|
52
53
|
rates = [e.get_propensity(model, current_time) for e in events]
|
|
53
54
|
|
|
54
55
|
instantaneous_events = [e for e, r in zip(events, rates) if r == np.inf]
|
|
55
56
|
if instantaneous_events:
|
|
56
57
|
event = instantaneous_events[rng.integers(len(instantaneous_events))]
|
|
57
|
-
event.apply(model, current_time)
|
|
58
|
+
event.apply(model, current_time, rng)
|
|
58
59
|
continue
|
|
59
60
|
|
|
60
|
-
if
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
61
|
+
if (
|
|
62
|
+
not any(rates)
|
|
63
|
+
or model.n_sampled > max_tips
|
|
64
|
+
or target_n_tips is not None
|
|
65
|
+
and model.n_sampled >= target_n_tips
|
|
66
|
+
):
|
|
65
67
|
break
|
|
66
68
|
|
|
67
|
-
|
|
69
|
+
time_step = rng.exponential(1 / sum(rates))
|
|
70
|
+
if current_time + time_step >= next_change_time:
|
|
68
71
|
current_time = next_change_time
|
|
69
72
|
next_change_time = change_times.pop(0) if change_times else np.inf
|
|
70
73
|
continue
|
|
74
|
+
if current_time + time_step >= max_time:
|
|
75
|
+
current_time = max_time
|
|
76
|
+
break
|
|
77
|
+
current_time += time_step
|
|
71
78
|
|
|
72
79
|
event_idx = np.searchsorted(np.cumsum(rates) / sum(rates), rng.random())
|
|
73
|
-
events[int(event_idx)].apply(model, current_time)
|
|
80
|
+
events[int(event_idx)].apply(model, current_time, rng)
|
|
74
81
|
|
|
75
82
|
for individual in model.get_population():
|
|
76
83
|
if rng.random() < sampling_probability_at_present:
|
|
77
|
-
model.sample(individual, current_time,
|
|
84
|
+
model.sample(individual, current_time, True)
|
|
78
85
|
|
|
79
|
-
if model.n_sampled
|
|
80
|
-
max_tips is None or model.n_sampled <= max_tips
|
|
81
|
-
):
|
|
86
|
+
if min_tips <= model.n_sampled <= max_tips:
|
|
82
87
|
return model.get_sampled_tree()
|
|
83
|
-
|
|
84
88
|
n_tries += 1
|
|
89
|
+
|
|
90
|
+
print("WARNING: Maximum number of tries reached, returning None.")
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def generate_trees(
|
|
94
|
+
output_dir: str,
|
|
95
|
+
n_trees: int,
|
|
96
|
+
events: Sequence[Event],
|
|
97
|
+
min_tips: int = 1,
|
|
98
|
+
max_tips: int = 2**32,
|
|
99
|
+
max_time: float = np.inf,
|
|
100
|
+
init_state: str | None = None,
|
|
101
|
+
sampling_probability_at_present: float = 0.0,
|
|
102
|
+
max_tries: int | None = None,
|
|
103
|
+
seed: int | None = None,
|
|
104
|
+
n_jobs: int = -1,
|
|
105
|
+
) -> None:
|
|
106
|
+
if os.path.exists(output_dir):
|
|
107
|
+
raise FileExistsError(f"Output directory {output_dir} already exists")
|
|
108
|
+
os.mkdir(output_dir)
|
|
109
|
+
|
|
110
|
+
rng = default_rng(seed)
|
|
111
|
+
jobs = joblib.Parallel(n_jobs=n_jobs, return_as="generator_unordered")(
|
|
112
|
+
joblib.delayed(simulate_tree)(
|
|
113
|
+
events=events,
|
|
114
|
+
min_tips=min_tips,
|
|
115
|
+
max_tips=max_tips,
|
|
116
|
+
max_time=max_time,
|
|
117
|
+
init_state=init_state,
|
|
118
|
+
sampling_probability_at_present=sampling_probability_at_present,
|
|
119
|
+
max_tries=max_tries,
|
|
120
|
+
seed=int(rng.integers(2**32)),
|
|
121
|
+
)
|
|
122
|
+
for _ in range(n_trees)
|
|
123
|
+
)
|
|
124
|
+
for i, tree in tqdm(
|
|
125
|
+
enumerate(jobs), total=n_trees, desc=f"Generating trees in {output_dir}..."
|
|
126
|
+
):
|
|
127
|
+
if tree is not None:
|
|
128
|
+
dump_newick(tree, os.path.join(output_dir, f"{i}.nwk"))
|
|
@@ -1,51 +1,25 @@
|
|
|
1
1
|
from collections import defaultdict
|
|
2
|
-
from dataclasses import dataclass
|
|
3
|
-
from typing import ClassVar
|
|
4
|
-
|
|
5
|
-
from numpy.random import Generator, default_rng
|
|
2
|
+
from dataclasses import dataclass
|
|
6
3
|
|
|
7
4
|
from phylogenie.tree import Tree
|
|
8
5
|
|
|
9
|
-
CT_POSTFIX = "-CT"
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
def get_CT_state(state: str) -> str:
|
|
13
|
-
return f"{state}{CT_POSTFIX}"
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
def is_CT_state(state: str) -> bool:
|
|
17
|
-
return state.endswith(CT_POSTFIX)
|
|
18
|
-
|
|
19
6
|
|
|
20
7
|
@dataclass
|
|
21
8
|
class Individual:
|
|
9
|
+
id: int
|
|
22
10
|
node: Tree
|
|
23
11
|
state: str
|
|
24
|
-
id: int = field(init=False)
|
|
25
|
-
_id_counter: ClassVar[int] = 0
|
|
26
|
-
|
|
27
|
-
def __post_init__(self):
|
|
28
|
-
Individual._id_counter += 1
|
|
29
|
-
self.id = Individual._id_counter
|
|
30
12
|
|
|
31
13
|
|
|
32
14
|
class Model:
|
|
33
|
-
def __init__(
|
|
34
|
-
self,
|
|
35
|
-
init_state: str,
|
|
36
|
-
max_notified_contacts: int = 1,
|
|
37
|
-
notification_probability: float = 0,
|
|
38
|
-
rng: int | Generator | None = None,
|
|
39
|
-
):
|
|
15
|
+
def __init__(self, init_state: str):
|
|
40
16
|
self._next_node_id = 0
|
|
17
|
+
self._next_individual_id = 0
|
|
41
18
|
self._population: dict[int, Individual] = {}
|
|
42
19
|
self._states: dict[str, set[int]] = defaultdict(set)
|
|
43
|
-
self.
|
|
20
|
+
self._lineages: dict[int, list[Individual]] = defaultdict(list)
|
|
44
21
|
self._sampled: set[str] = set()
|
|
45
22
|
self._tree = self._get_new_individual(init_state).node
|
|
46
|
-
self._max_notified_contacts = max_notified_contacts
|
|
47
|
-
self._notification_probability = notification_probability
|
|
48
|
-
self._rng = rng if isinstance(rng, Generator) else default_rng(rng)
|
|
49
23
|
|
|
50
24
|
@property
|
|
51
25
|
def n_sampled(self) -> int:
|
|
@@ -56,7 +30,10 @@ class Model:
|
|
|
56
30
|
return Tree(f"{self._next_node_id}|{state}")
|
|
57
31
|
|
|
58
32
|
def _get_new_individual(self, state: str) -> Individual:
|
|
59
|
-
|
|
33
|
+
self._next_individual_id += 1
|
|
34
|
+
individual = Individual(
|
|
35
|
+
self._next_individual_id, self._get_new_node(state), state
|
|
36
|
+
)
|
|
60
37
|
self._population[individual.id] = individual
|
|
61
38
|
self._states[state].add(individual.id)
|
|
62
39
|
return individual
|
|
@@ -93,12 +70,12 @@ class Model:
|
|
|
93
70
|
new_individual = self._get_new_individual(state)
|
|
94
71
|
individual.node.add_child(new_individual.node)
|
|
95
72
|
self._stem(individual, time)
|
|
96
|
-
self.
|
|
97
|
-
self.
|
|
73
|
+
self._lineages[id].append(new_individual)
|
|
74
|
+
self._lineages[new_individual.id].append(individual)
|
|
98
75
|
|
|
99
|
-
def sample(self, id: int, time: float,
|
|
76
|
+
def sample(self, id: int, time: float, removal: bool) -> None:
|
|
100
77
|
individual = self._population[id]
|
|
101
|
-
if
|
|
78
|
+
if removal:
|
|
102
79
|
self._sampled.add(individual.node.id)
|
|
103
80
|
self.remove(id, time)
|
|
104
81
|
else:
|
|
@@ -108,13 +85,11 @@ class Model:
|
|
|
108
85
|
individual.node.add_child(sample_node)
|
|
109
86
|
self._stem(individual, time)
|
|
110
87
|
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
):
|
|
117
|
-
self.migrate(contact.id, get_CT_state(contact.state), time)
|
|
88
|
+
def get_lineage(self, id: int) -> list[int]:
|
|
89
|
+
return [individual.id for individual in self._lineages[id]]
|
|
90
|
+
|
|
91
|
+
def get_state(self, id: int) -> str:
|
|
92
|
+
return self._population[id].state
|
|
118
93
|
|
|
119
94
|
def get_sampled_tree(self) -> Tree:
|
|
120
95
|
tree = self._tree.copy()
|
|
@@ -140,13 +115,10 @@ class Model:
|
|
|
140
115
|
def get_full_tree(self) -> Tree:
|
|
141
116
|
return self._tree.copy()
|
|
142
117
|
|
|
143
|
-
def
|
|
118
|
+
def get_population(self, state: str | None = None) -> list[int]:
|
|
144
119
|
if state is None:
|
|
145
|
-
return
|
|
146
|
-
return
|
|
147
|
-
|
|
148
|
-
def get_population(self) -> list[int]:
|
|
149
|
-
return list(self._population)
|
|
120
|
+
return list(self._population)
|
|
121
|
+
return list(self._states[state])
|
|
150
122
|
|
|
151
123
|
def count_individuals(self, state: str | None = None) -> int:
|
|
152
124
|
if state is None:
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
phylogenie/__init__.py,sha256=sv8sfqkbarYgeaCFGiGMtV_fTUPuXEyJS_I3W9mlIto,1801
|
|
2
|
+
phylogenie/generators/__init__.py,sha256=zsOxy28-9j9alOQLIgrOAFfmM58NNHO_NEtW-KXQXAY,888
|
|
3
|
+
phylogenie/generators/alisim.py,sha256=dDqlSwLDbRE2u5SZlsq1mArobTBtuk0aeXY3m1N-bWA,2374
|
|
4
|
+
phylogenie/generators/configs.py,sha256=4jSBUZiFo2GacXWed5dy7lUEkaOWZkZG-KY9vHfhqGU,993
|
|
5
|
+
phylogenie/generators/dataset.py,sha256=UTsf8u868_8K6aMwIpLZrIfSY7s9skXlLUqQBuetiNQ,2954
|
|
6
|
+
phylogenie/generators/factories.py,sha256=Y7cTsIblyV9T7ZhYvyQ5Wd7JcCpMG1dkEF8jJ1iQxN8,6928
|
|
7
|
+
phylogenie/generators/trees.py,sha256=R5ZlZ-UNy2euPBt2P3lyit4txC3xweLrPPA1FI6m5PQ,9210
|
|
8
|
+
phylogenie/generators/typeguards.py,sha256=yj4VkhOaUXJ2OrY-6zhOeY9C4yKIQxjZtk2d-vIxttQ,828
|
|
9
|
+
phylogenie/io.py,sha256=d1xF6rwER6KpnA5OrFgDn7Ow-YymvxA7OXX7hi_nfB4,2951
|
|
10
|
+
phylogenie/main.py,sha256=vtvSpQxBNlYABoFQ25czl-l3fIr4QRo3svWVd-jcArw,1170
|
|
11
|
+
phylogenie/msa.py,sha256=JDGyZUsAq6-m-SQjoCDjAkAZIxfgyl_PDIhdYn5HOow,2064
|
|
12
|
+
phylogenie/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
13
|
+
phylogenie/skyline/__init__.py,sha256=7pF4CUb4ZCLzNYJNhOjpuTOLTRhlK7L6ugfccNqjIGo,620
|
|
14
|
+
phylogenie/skyline/matrix.py,sha256=Gl8OgKjtieG0NwPYiPimKI36gefV8fm_OeorjdXxPTs,9146
|
|
15
|
+
phylogenie/skyline/parameter.py,sha256=EM9qlPt0JhMBy3TbztM0dj24BaGNEy8KWKdTObDKhbI,4644
|
|
16
|
+
phylogenie/skyline/vector.py,sha256=bJP7_FNX_Klt6wXqsyfj0KX3VNj6-dIhzCKSJuQcOV0,7115
|
|
17
|
+
phylogenie/tree.py,sha256=dk8Sj1tqyGOunVO2crtIqb0LH-ws-PXqA8SuNcYfVHI,1738
|
|
18
|
+
phylogenie/treesimulator/__init__.py,sha256=DGn_sRDwL4OY1x1fT36kh4ghhwqSGt_8FnrV_TcQCjs,563
|
|
19
|
+
phylogenie/treesimulator/events.py,sha256=xV64Y_oH9tsAFVVJEGW6-VgiOcX-xNPR_niTxTmpARo,10583
|
|
20
|
+
phylogenie/treesimulator/gillespie.py,sha256=q5t0jfZWRqoyoiXiImMmo8fXqo7Cw1ea-OS_8aCD6Mc,4491
|
|
21
|
+
phylogenie/treesimulator/model.py,sha256=Zl82nlbuq0htrLZV7x5LAB-thuN4lzbDv5pDSrG3oM8,4595
|
|
22
|
+
phylogenie/typeguards.py,sha256=JtqmbEWJZBRHbWgCvcl6nrWm3VcBfzRbklbTBYHItn0,1325
|
|
23
|
+
phylogenie/typings.py,sha256=O1X6lGKTjJ2YJz3ApQ-rYb_tEJNUIcHdUIeYlSM4s5o,500
|
|
24
|
+
phylogenie-2.1.0.dist-info/LICENSE.txt,sha256=NUrDqElK-eD3I0WqC004CJsy6cs0JgsAoebDv_42-pw,1071
|
|
25
|
+
phylogenie-2.1.0.dist-info/METADATA,sha256=5_QZd6c3rvlHwhRC-TT6bqcNG6gZDJJGLuNWVUorzlg,5472
|
|
26
|
+
phylogenie-2.1.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
27
|
+
phylogenie-2.1.0.dist-info/entry_points.txt,sha256=Rt6_usN0FkBX1ZfiqCirjMN9FKOgFLG8rydcQ8kugeE,51
|
|
28
|
+
phylogenie-2.1.0.dist-info/RECORD,,
|
|
@@ -1,28 +0,0 @@
|
|
|
1
|
-
phylogenie/__init__.py,sha256=4BytT42_M1K6T3W9eqHQCrKc6g0Lh5LTQxP8dGIJTsk,1915
|
|
2
|
-
phylogenie/generators/__init__.py,sha256=VCpuvmOoY_N6p1h_Q0peYgGIIUIFLsvZT3T7vbHG6w0,1090
|
|
3
|
-
phylogenie/generators/alisim.py,sha256=dDqlSwLDbRE2u5SZlsq1mArobTBtuk0aeXY3m1N-bWA,2374
|
|
4
|
-
phylogenie/generators/configs.py,sha256=4jSBUZiFo2GacXWed5dy7lUEkaOWZkZG-KY9vHfhqGU,993
|
|
5
|
-
phylogenie/generators/dataset.py,sha256=hbkN5McM4BKY7D0hLNaxdoAGsLHac6O-D4sgnZ0wFX4,2618
|
|
6
|
-
phylogenie/generators/factories.py,sha256=0ckeAsKnPy69Vbdoi1rIyf6zRcqamz9VfSi0mAiTzds,6938
|
|
7
|
-
phylogenie/generators/trees.py,sha256=jukaVXGcPGzDBEYMGJ1MKqWt4XbAB5EEfuHXDpwKTqM,9173
|
|
8
|
-
phylogenie/generators/typeguards.py,sha256=yj4VkhOaUXJ2OrY-6zhOeY9C4yKIQxjZtk2d-vIxttQ,828
|
|
9
|
-
phylogenie/io.py,sha256=ZXlofnSh7FX5UJiP0svRHrTraMSNgKa1GiAv0bMz7jU,2854
|
|
10
|
-
phylogenie/main.py,sha256=vtvSpQxBNlYABoFQ25czl-l3fIr4QRo3svWVd-jcArw,1170
|
|
11
|
-
phylogenie/msa.py,sha256=JDGyZUsAq6-m-SQjoCDjAkAZIxfgyl_PDIhdYn5HOow,2064
|
|
12
|
-
phylogenie/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
13
|
-
phylogenie/skyline/__init__.py,sha256=7pF4CUb4ZCLzNYJNhOjpuTOLTRhlK7L6ugfccNqjIGo,620
|
|
14
|
-
phylogenie/skyline/matrix.py,sha256=Gl8OgKjtieG0NwPYiPimKI36gefV8fm_OeorjdXxPTs,9146
|
|
15
|
-
phylogenie/skyline/parameter.py,sha256=EM9qlPt0JhMBy3TbztM0dj24BaGNEy8KWKdTObDKhbI,4644
|
|
16
|
-
phylogenie/skyline/vector.py,sha256=bJP7_FNX_Klt6wXqsyfj0KX3VNj6-dIhzCKSJuQcOV0,7115
|
|
17
|
-
phylogenie/tree.py,sha256=dk8Sj1tqyGOunVO2crtIqb0LH-ws-PXqA8SuNcYfVHI,1738
|
|
18
|
-
phylogenie/treesimulator/__init__.py,sha256=INPU9LrPdUmt3dYGzWDRoRKrPR9xENcHu44pJVUbyNA,525
|
|
19
|
-
phylogenie/treesimulator/events.py,sha256=X3_0U9qqMpYgh6-7TwQEnlUipANkHz6QTCXlm-qXFQk,9524
|
|
20
|
-
phylogenie/treesimulator/gillespie.py,sha256=4uMt_-Rr3cRXWGKC8veBIB-uqtKtN-dLbAHKjAi_5Mo,3182
|
|
21
|
-
phylogenie/treesimulator/model.py,sha256=XpzAicmg2O6K0Trk5YolH-B_HJZxoSauF2wZOMqp-Iw,5559
|
|
22
|
-
phylogenie/typeguards.py,sha256=JtqmbEWJZBRHbWgCvcl6nrWm3VcBfzRbklbTBYHItn0,1325
|
|
23
|
-
phylogenie/typings.py,sha256=O1X6lGKTjJ2YJz3ApQ-rYb_tEJNUIcHdUIeYlSM4s5o,500
|
|
24
|
-
phylogenie-2.0.13.dist-info/LICENSE.txt,sha256=NUrDqElK-eD3I0WqC004CJsy6cs0JgsAoebDv_42-pw,1071
|
|
25
|
-
phylogenie-2.0.13.dist-info/METADATA,sha256=XkYNiu3IYt516JJ8dw-iIo9zonC8LFiB_ZVGhNyw-eY,5473
|
|
26
|
-
phylogenie-2.0.13.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
27
|
-
phylogenie-2.0.13.dist-info/entry_points.txt,sha256=Rt6_usN0FkBX1ZfiqCirjMN9FKOgFLG8rydcQ8kugeE,51
|
|
28
|
-
phylogenie-2.0.13.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|