phylogenie 1.0.8__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. phylogenie/generators/__init__.py +14 -0
  2. phylogenie/generators/alisim.py +71 -0
  3. phylogenie/generators/configs.py +41 -0
  4. phylogenie/{core → generators}/dataset.py +25 -23
  5. phylogenie/{core → generators}/factories.py +42 -52
  6. phylogenie/generators/trees.py +220 -0
  7. phylogenie/generators/typeguards.py +32 -0
  8. phylogenie/io.py +92 -0
  9. phylogenie/main.py +2 -2
  10. phylogenie/msa.py +72 -0
  11. phylogenie/skyline/matrix.py +62 -45
  12. phylogenie/skyline/vector.py +8 -6
  13. phylogenie/tree.py +53 -0
  14. phylogenie/treesimulator/__init__.py +21 -0
  15. phylogenie/treesimulator/events.py +256 -0
  16. phylogenie/treesimulator/gillespie.py +66 -0
  17. phylogenie/treesimulator/model.py +100 -0
  18. phylogenie/typings.py +0 -2
  19. {phylogenie-1.0.8.dist-info → phylogenie-2.0.0.dist-info}/METADATA +6 -18
  20. phylogenie-2.0.0.dist-info/RECORD +28 -0
  21. phylogenie/backend/__init__.py +0 -0
  22. phylogenie/backend/remaster/__init__.py +0 -21
  23. phylogenie/backend/remaster/generate.py +0 -187
  24. phylogenie/backend/remaster/reactions.py +0 -165
  25. phylogenie/backend/treesimulator.py +0 -163
  26. phylogenie/configs.py +0 -5
  27. phylogenie/core/__init__.py +0 -14
  28. phylogenie/core/configs.py +0 -37
  29. phylogenie/core/context/__init__.py +0 -4
  30. phylogenie/core/context/configs.py +0 -28
  31. phylogenie/core/context/distributions.py +0 -125
  32. phylogenie/core/context/factories.py +0 -54
  33. phylogenie/core/msas/__init__.py +0 -10
  34. phylogenie/core/msas/alisim.py +0 -35
  35. phylogenie/core/msas/base.py +0 -51
  36. phylogenie/core/trees/__init__.py +0 -11
  37. phylogenie/core/trees/base.py +0 -13
  38. phylogenie/core/trees/remaster/__init__.py +0 -3
  39. phylogenie/core/trees/remaster/configs.py +0 -14
  40. phylogenie/core/trees/remaster/factories.py +0 -26
  41. phylogenie/core/trees/remaster/generator.py +0 -177
  42. phylogenie/core/trees/treesimulator.py +0 -199
  43. phylogenie/core/typeguards.py +0 -32
  44. phylogenie-1.0.8.dist-info/RECORD +0 -39
  45. {phylogenie-1.0.8.dist-info → phylogenie-2.0.0.dist-info}/LICENSE.txt +0 -0
  46. {phylogenie-1.0.8.dist-info → phylogenie-2.0.0.dist-info}/WHEEL +0 -0
  47. {phylogenie-1.0.8.dist-info → phylogenie-2.0.0.dist-info}/entry_points.txt +0 -0
phylogenie/generators/__init__.py
@@ -0,0 +1,14 @@
+from typing import Annotated
+
+from pydantic import Field
+
+from phylogenie.generators.alisim import AliSimDatasetGenerator
+from phylogenie.generators.dataset import DatasetGenerator
+from phylogenie.generators.trees import TreeDatasetGeneratorConfig
+
+DatasetGeneratorConfig = Annotated[
+    TreeDatasetGeneratorConfig | AliSimDatasetGenerator,
+    Field(discriminator="data_type"),
+]
+
+__all__ = ["DatasetGeneratorConfig", "DatasetGenerator"]
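Note: `DatasetGeneratorConfig` is a Pydantic discriminated union keyed on `data_type` (and, for tree generators, further on `parameterization`). A minimal usage sketch, assuming `DataType.TREES` serializes to "trees" (only the "msas" member is visible in this diff) and using illustrative parameter values:

from pydantic import TypeAdapter

from phylogenie.generators import DatasetGeneratorConfig

# Hypothetical raw config; the field names come from BDTreeDatasetGenerator (see trees.py below).
raw = {
    "data_type": "trees",        # outer discriminator (DataType)
    "parameterization": "BD",    # inner discriminator (ParameterizationType)
    "reproduction_number": 2.5,
    "infectious_period": 5.0,
    "min_tips": 100,
    "max_tips": 500,
}
generator = TypeAdapter(DatasetGeneratorConfig).validate_python(raw)
generator.generate_one("example")  # simulates one tree and writes example.nwk (if simulation succeeds)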
phylogenie/generators/alisim.py
@@ -0,0 +1,71 @@
+import os
+import subprocess
+from pathlib import Path
+from typing import Any, Literal
+
+from numpy.random import Generator
+
+from phylogenie.generators.dataset import DatasetGenerator, DataType
+from phylogenie.generators.trees import TreeDatasetGeneratorConfig
+from phylogenie.io import dump_newick
+
+MSAS_DIRNAME = "MSAs"
+TREES_DIRNAME = "trees"
+
+
+class AliSimDatasetGenerator(DatasetGenerator):
+    data_type: Literal[DataType.MSAS] = DataType.MSAS
+    trees: TreeDatasetGeneratorConfig
+    keep_trees: bool = False
+    iqtree_path: str = "iqtree2"
+    args: dict[str, str | int | float]
+
+    def _generate_one_from_tree(
+        self, filename: str, tree_file: str, rng: Generator, data: dict[str, Any]
+    ) -> None:
+        command = [
+            self.iqtree_path,
+            "--alisim",
+            filename,
+            "--tree",
+            tree_file,
+            "--seed",
+            str(rng.integers(2**32)),
+        ]
+
+        for key, value in self.args.items():
+            command.extend(
+                [key, value.format(**data) if isinstance(value, str) else str(value)]
+            )
+
+        command.extend(["-af", "fasta"])
+        subprocess.run(command, check=True, stdout=subprocess.DEVNULL)
+        subprocess.run(["rm", f"{tree_file}.log"], check=True)
+
+    def _generate_one(
+        self, filename: str, rng: Generator, data: dict[str, Any]
+    ) -> None:
+        if self.keep_trees:
+            base_dir = Path(filename).parent
+            file_id = Path(filename).stem
+            tree_filename = os.path.join(base_dir, TREES_DIRNAME, file_id)
+            msas_dir = os.path.join(base_dir, MSAS_DIRNAME)
+            os.makedirs(msas_dir, exist_ok=True)
+            msa_filename = os.path.join(msas_dir, file_id)
+        else:
+            tree_filename = f"{filename}.temp-tree"
+            msa_filename = filename
+
+        tree = self.trees.simulate_one(rng, data)
+        if tree is None:
+            return
+
+        for leaf in tree.get_leaves():
+            leaf.id += f"|{leaf.get_time()}"
+        dump_newick(tree, f"{tree_filename}.nwk")
+
+        self._generate_one_from_tree(
+            filename=msa_filename, tree_file=f"{tree_filename}.nwk", rng=rng, data=data
+        )
+        if not self.keep_trees:
+            os.remove(f"{tree_filename}.nwk")
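Note: `args` entries are appended verbatim to the iqtree2 invocation, with string values formatted against the per-sample context. A sketch of the resulting command, where the `-m` and `--length` flags are illustrative AliSim/IQ-TREE options rather than anything mandated by this class:

# Illustration of how AliSimDatasetGenerator builds its subprocess call.
args = {"-m": "GTR+G", "--length": 1000}   # hypothetical user-supplied args
data = {}                                   # per-sample context used by str.format()

command = ["iqtree2", "--alisim", "out/0", "--tree", "out/0.temp-tree.nwk", "--seed", "12345"]
for key, value in args.items():
    command.extend([key, value.format(**data) if isinstance(value, str) else str(value)])
command.extend(["-af", "fasta"])
# command == ['iqtree2', '--alisim', 'out/0', '--tree', 'out/0.temp-tree.nwk', '--seed', '12345',
#             '-m', 'GTR+G', '--length', '1000', '-af', 'fasta']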
phylogenie/generators/configs.py
@@ -0,0 +1,41 @@
+from pydantic import BaseModel, ConfigDict
+
+import phylogenie.typings as pgt
+
+
+class DistributionConfig(BaseModel):
+    type: str
+    model_config = ConfigDict(extra="allow")
+
+
+IntegerConfig = str | int
+ScalarConfig = str | pgt.Scalar
+ManyScalarsConfig = str | list[ScalarConfig]
+OneOrManyScalarsConfig = ScalarConfig | list[ScalarConfig]
+OneOrMany2DScalarsConfig = ScalarConfig | list[list[ScalarConfig]]
+
+
+class StrictBaseModel(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
+
+class SkylineParameterModel(StrictBaseModel):
+    value: ManyScalarsConfig
+    change_times: ManyScalarsConfig
+
+
+class SkylineVectorModel(StrictBaseModel):
+    value: str | list[OneOrManyScalarsConfig]
+    change_times: ManyScalarsConfig
+
+
+class SkylineMatrixModel(StrictBaseModel):
+    value: str | list[OneOrMany2DScalarsConfig]
+    change_times: ManyScalarsConfig
+
+
+SkylineParameterConfig = ScalarConfig | SkylineParameterModel
+SkylineVectorConfig = (
+    str | pgt.Scalar | list[SkylineParameterConfig] | SkylineVectorModel
+)
+SkylineMatrixConfig = str | pgt.Scalar | list[SkylineVectorConfig] | SkylineMatrixModel
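Note: the skyline aliases accept several equivalent spellings. A sketch of three ways a `SkylineVectorConfig` might be written (values are made up; the expression string assumes a `sampling_rate` key in the per-sample context, and the list-of-lists form follows the usual skyline convention of one more value than change time):

v1 = 0.5                              # single scalar, broadcast by the factories
v2 = [0.1, "sampling_rate * 2"]       # list of SkylineParameterConfig entries
v3 = {                                # SkylineVectorModel: one vector of values per time interval
    "value": [[0.1, 0.2], [0.3, 0.4]],
    "change_times": [1.0],
}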
phylogenie/{core → generators}/dataset.py
@@ -1,15 +1,16 @@
 import os
 from abc import ABC, abstractmethod
 from enum import Enum
+from itertools import product
+from typing import Any

 import joblib
+import numpy as np
 import pandas as pd
 from numpy.random import Generator, default_rng
 from tqdm import tqdm

-import phylogenie.typings as pgt
-from phylogenie.configs import StrictBaseModel
-from phylogenie.core.context import ContextConfig, context_factory
+from phylogenie.generators.configs import DistributionConfig, StrictBaseModel


 class DataType(str, Enum):
@@ -17,51 +18,52 @@ class DataType(str, Enum):
     MSAS = "msas"


+DATA_DIRNAME = "data"
+METADATA_FILENAME = "metadata.csv"
+
+
 class DatasetGenerator(ABC, StrictBaseModel):
-    output_dir: str = "phylogenie-out"
-    data_dir: str = "data"
-    metadata_filename: str = "metadata.csv"
+    output_dir: str = "phylogenie-outputs"
     n_samples: int | dict[str, int] = 1
     n_jobs: int = -1
     seed: int | None = None
-    context: ContextConfig | None = None
+    context: dict[str, DistributionConfig] | None = None

     @abstractmethod
-    def _generate_one(self, filename: str, rng: Generator, data: pgt.Data) -> None: ...
+    def _generate_one(
+        self, filename: str, rng: Generator, data: dict[str, Any]
+    ) -> None: ...

     def generate_one(
-        self, filename: str, data: pgt.Data | None = None, seed: int | None = None
+        self, filename: str, data: dict[str, Any] | None = None, seed: int | None = None
     ) -> None:
         data = {} if data is None else data
         self._generate_one(filename=filename, rng=default_rng(seed), data=data)

     def _generate(self, rng: Generator, n_samples: int, output_dir: str) -> None:
-        data_dir = os.path.join(output_dir, self.data_dir)
-        metadata_file = os.path.join(output_dir, self.metadata_filename)
+        data_dir = os.path.join(output_dir, DATA_DIRNAME)
         if os.path.exists(data_dir):
             print(f"Output directory {data_dir} already exists. Skipping.")
             return
         os.makedirs(data_dir)

-        data = [
-            {} if self.context is None else context_factory(self.context, rng)
-            for _ in range(n_samples)
-        ]
+        data: list[dict[str, Any]] = [{}] * n_samples
+        if self.context is not None:
+            for d, (k, v) in product(data, self.context.items()):
+                args = v.model_extra if v.model_extra is not None else {}
+                d[k] = np.array(getattr(rng, v.type)(**args)).tolist()
+        df = pd.DataFrame([{"file_id": str(i), **d} for i, d in enumerate(data)])
+        df.to_csv(os.path.join(output_dir, METADATA_FILENAME), index=False)

         joblib.Parallel(n_jobs=self.n_jobs)(
             joblib.delayed(self.generate_one)(
                 filename=os.path.join(data_dir, str(i)),
-                data=d,
-                seed=int(rng.integers(0, 2**32)),
-            )
-            for i, d in tqdm(
-                enumerate(data), total=n_samples, desc=f"Generating {data_dir}..."
+                data=data[i],
+                seed=int(rng.integers(2**32)),
             )
+            for i in tqdm(range(n_samples), desc=f"Generating {data_dir}...")
         )

-        df = pd.DataFrame([{"file_id": str(i), **d} for i, d in enumerate(data)])
-        df.to_csv(metadata_file, index=False)
-
     def generate(self) -> None:
         rng = default_rng(self.seed)
         if isinstance(self.n_samples, dict):
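Note: each `context` entry now names a numpy `Generator` method via its `type` field, with any extra keys forwarded as keyword arguments; the drawn value is stored in the sample's data dict and written to metadata.csv. A minimal sketch of the draw performed for one entry (the `R0` key and distribution parameters are illustrative):

import numpy as np
from numpy.random import default_rng

# {"R0": {"type": "uniform", "low": 1.0, "high": 5.0}} resolves roughly to:
rng = default_rng(0)
args = {"low": 1.0, "high": 5.0}                             # v.model_extra
value = np.array(getattr(rng, "uniform")(**args)).tolist()   # i.e. rng.uniform(low=1.0, high=5.0)
sample_data = {"R0": value}                                  # per-sample context passed to the generator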
phylogenie/{core → generators}/factories.py
@@ -2,8 +2,8 @@ from typing import Any

 import numpy as np

-import phylogenie.core.configs as cfg
-import phylogenie.core.typeguards as ctg
+import phylogenie.generators.configs as cfg
+import phylogenie.generators.typeguards as ctg
 import phylogenie.typeguards as tg
 import phylogenie.typings as pgt
 from phylogenie.skyline import (
@@ -16,7 +16,7 @@ from phylogenie.skyline import (
 )


-def _eval_expression(expression: str, data: pgt.Data) -> Any:
+def _eval_expression(expression: str, data: dict[str, Any]) -> Any:
     return np.array(
         eval(
             expression,
@@ -29,7 +29,7 @@ def _eval_expression(expression: str, data: pgt.Data) -> Any:
     ).tolist()


-def int_factory(x: cfg.IntConfig, data: pgt.Data) -> int:
+def integer(x: cfg.IntegerConfig, data: dict[str, Any]) -> int:
     if isinstance(x, str):
         e = _eval_expression(x, data)
         if isinstance(e, int):
@@ -40,7 +40,7 @@ def int_factory(x: cfg.IntConfig, data: pgt.Data) -> int:
     return x


-def scalar_factory(x: cfg.ScalarConfig, data: pgt.Data) -> pgt.Scalar:
+def scalar(x: cfg.ScalarConfig, data: dict[str, Any]) -> pgt.Scalar:
     if isinstance(x, str):
         e = _eval_expression(x, data)
         if isinstance(e, pgt.Scalar):
@@ -51,18 +51,7 @@ def scalar_factory(x: cfg.ScalarConfig, data: pgt.Data) -> pgt.Scalar:
     return x


-def many_ints_factory(x: cfg.ManyIntsConfig, data: pgt.Data) -> pgt.Many[int]:
-    if isinstance(x, str):
-        e = _eval_expression(x, data)
-        if tg.is_many_ints(e):
-            return e
-        raise ValueError(
-            f"Expression '{x}' evaluated to {e} of type {type(e)}, expected a sequence of integers."
-        )
-    return [int_factory(v, data) for v in x]
-
-
-def many_scalars_factory(x: cfg.ManyScalarsConfig, data: pgt.Data) -> pgt.ManyScalars:
+def many_scalars(x: cfg.ManyScalarsConfig, data: dict[str, Any]) -> pgt.ManyScalars:
     if isinstance(x, str):
         e = _eval_expression(x, data)
         if tg.is_many_scalars(e):
@@ -70,11 +59,11 @@ def many_scalars_factory(x: cfg.ManyScalarsConfig, data: pgt.Data) -> pgt.ManySc
         raise ValueError(
             f"Expression '{x}' evaluated to {e} of type {type(e)}, expected a sequence of scalars."
         )
-    return [scalar_factory(v, data) for v in x]
+    return [scalar(v, data) for v in x]


-def one_or_many_scalars_factory(
-    x: cfg.OneOrManyScalarsConfig, data: pgt.Data
+def one_or_many_scalars(
+    x: cfg.OneOrManyScalarsConfig, data: dict[str, Any]
 ) -> pgt.OneOrManyScalars:
     if isinstance(x, str):
         e = _eval_expression(x, data)
@@ -85,22 +74,22 @@
         )
     if isinstance(x, pgt.Scalar):
         return x
-    return many_scalars_factory(x, data)
+    return many_scalars(x, data)


-def skyline_parameter_like_factory(
-    x: cfg.SkylineParameterLikeConfig, data: pgt.Data
+def skyline_parameter(
+    x: cfg.SkylineParameterConfig, data: dict[str, Any]
 ) -> SkylineParameterLike:
     if isinstance(x, cfg.ScalarConfig):
-        return scalar_factory(x, data)
+        return scalar(x, data)
     return SkylineParameter(
-        value=many_scalars_factory(x.value, data),
-        change_times=many_scalars_factory(x.change_times, data),
+        value=many_scalars(x.value, data),
+        change_times=many_scalars(x.change_times, data),
     )


-def skyline_vector_coercible_factory(
-    x: cfg.SkylineVectorCoercibleConfig, data: pgt.Data
+def skyline_vector(
+    x: cfg.SkylineVectorConfig, data: dict[str, Any]
 ) -> SkylineVectorCoercible:
     if isinstance(x, str):
         e = _eval_expression(x, data)
@@ -111,12 +100,12 @@ def skyline_vector_coercible_factory(
         )
     if isinstance(x, pgt.Scalar):
         return x
-    if ctg.is_list_of_skyline_parameter_like_configs(x):
-        return [skyline_parameter_like_factory(p, data) for p in x]
+    if ctg.is_list_of_skyline_parameter_configs(x):
+        return [skyline_parameter(p, data) for p in x]

-    assert isinstance(x, cfg.SkylineVectorValueModel)
+    assert isinstance(x, cfg.SkylineVectorModel)

-    change_times = many_scalars_factory(x.change_times, data)
+    change_times = many_scalars(x.change_times, data)
     if isinstance(x.value, str):
         e = _eval_expression(x.value, data)
         if tg.is_many_one_or_many_scalars(e):
@@ -126,7 +115,7 @@ def skyline_vector_coercible_factory(
                 f"Expression '{x.value}' evaluated to {e} of type {type(e)}, which cannot be coerced to a valid value for a SkylineVector (expected a sequence composed of scalars and/or sequences of scalars)."
             )
     else:
-        value = [one_or_many_scalars_factory(v, data) for v in x.value]
+        value = [one_or_many_scalars(v, data) for v in x.value]

     if tg.is_many_scalars(value):
         return SkylineParameter(value=value, change_times=change_times)
@@ -142,8 +131,8 @@
     return SkylineVector(value=value, change_times=change_times)


-def one_or_many_2D_scalars_factory(
-    x: cfg.OneOrMany2DScalarsConfig, data: pgt.Data
+def one_or_many_2D_scalars(
+    x: cfg.OneOrMany2DScalarsConfig, data: dict[str, Any]
 ) -> pgt.OneOrMany2DScalars:
     if isinstance(x, str):
         e = _eval_expression(x, data)
@@ -154,11 +143,11 @@
         )
     if isinstance(x, pgt.Scalar):
         return x
-    return [many_scalars_factory(v, data) for v in x]
+    return [many_scalars(v, data) for v in x]


-def skyline_matrix_coercible_factory(
-    x: cfg.SkylineMatrixCoercibleConfig, data: pgt.Data
+def skyline_matrix(
+    x: cfg.SkylineMatrixConfig, data: dict[str, Any]
 ) -> SkylineMatrixCoercible:
     if isinstance(x, str):
         e = _eval_expression(x, data)
@@ -169,12 +158,12 @@
         )
     if isinstance(x, pgt.Scalar):
         return x
-    if ctg.is_list_of_skyline_vector_coercible_configs(x):
-        return [skyline_vector_coercible_factory(v, data) for v in x]
+    if ctg.is_list_of_skyline_vector_configs(x):
+        return [skyline_vector(v, data) for v in x]

-    assert isinstance(x, cfg.SkylineMatrixValueModel)
+    assert isinstance(x, cfg.SkylineMatrixModel)

-    change_times = many_scalars_factory(x.change_times, data)
+    change_times = many_scalars(x.change_times, data)
     if isinstance(x.value, str):
         e = _eval_expression(x.value, data)
         if tg.is_many_one_or_many_2D_scalars(e):
@@ -184,26 +173,27 @@
                 f"Expression '{x.value}' evaluated to {e} of type {type(e)}, which cannot be coerced to a valid value for a SkylineMatrix (expected a sequence composed of scalars and/or nested (2D) sequences of scalars)."
             )
     else:
-        value = [one_or_many_2D_scalars_factory(v, data) for v in x.value]
+        value = [one_or_many_2D_scalars(v, data) for v in x.value]

     if tg.is_many_scalars(value):
         return SkylineParameter(value=value, change_times=change_times)

-    Ns: set[int] = set()
+    shapes: set[tuple[int, int]] = set()
     for elem in value:
         if tg.is_many_2D_scalars(elem):
-            n_rows = len(elem)
-            if any(len(row) != n_rows for row in elem):
+            Ms = len(elem)
+            Ns = {len(row) for row in elem}
+            if len(Ns) > 1:
                 raise ValueError(
-                    f"All elements in the value of a SkylineMatrix config must be scalars or square matrices (config {x.value} yeilded a non-square matrix: {elem})."
+                    f"The values of a SkylineMatrix config must be scalars or nested (2D) lists of them with a consistent row length (config {x.value} yielded element {elem} with row lengths {Ns})."
                 )
-            Ns.add(n_rows)
+            shapes.add((Ms, Ns.pop()))

-    if len(Ns) > 1:
+    if len(shapes) > 1:
         raise ValueError(
-            f"All elements in the value of a SkylineMatrix config must be scalars or have the same square shape (config {x.value} yielded value={value} with inconsistent lengths {Ns})."
+            f"All elements in the value of a SkylineMatrix config must be scalars or nested (2D) lists of them with the same shape (config {x.value} yielded value={value} with inconsistent shapes {shapes})."
         )
-    (N,) = Ns
-    value = [[[p] * N] * N if isinstance(p, pgt.Scalar) else p for p in value]
+    ((M, N),) = shapes
+    value = [[[e] * N] * M if isinstance(e, pgt.Scalar) else e for e in value]

     return SkylineMatrix(value=value, change_times=change_times)
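Note: the renamed factory helpers all follow the same convention: a string config is evaluated as an expression against the per-sample context and coerced, while non-string configs are used directly. A small sketch under that assumption, with illustrative context keys:

from phylogenie.generators.configs import SkylineParameterModel
from phylogenie.generators.factories import integer, scalar, skyline_parameter

data = {"R0": 2.5, "n": 3}
integer("10 * n", data)    # -> 30: the expression is evaluated against the context
scalar(1.7, data)          # -> 1.7: non-string configs pass through unchanged
skyline_parameter(
    SkylineParameterModel(value=[2.0, "R0"], change_times=[1.0]), data
)                          # -> SkylineParameter with values [2.0, 2.5], changing at t=1.0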
phylogenie/generators/trees.py
@@ -0,0 +1,220 @@
+from abc import abstractmethod
+from enum import Enum
+from typing import Annotated, Any, Literal
+
+import numpy as np
+from numpy.random import Generator
+from pydantic import Field
+
+import phylogenie.generators.configs as cfg
+from phylogenie.generators.dataset import DatasetGenerator, DataType
+from phylogenie.generators.factories import (
+    integer,
+    scalar,
+    skyline_matrix,
+    skyline_parameter,
+    skyline_vector,
+)
+from phylogenie.io import dump_newick
+from phylogenie.tree import Tree
+from phylogenie.treesimulator import (
+    Event,
+    get_BD_events,
+    get_BDEI_events,
+    get_BDSS_events,
+    get_canonical_events,
+    get_epidemiological_events,
+    get_FBD_events,
+    simulate_tree,
+)
+
+
+class ParameterizationType(str, Enum):
+    CANONICAL = "canonical"
+    EPIDEMIOLOGICAL = "epidemiological"
+    FBD = "FBD"
+    BD = "BD"
+    BDEI = "BDEI"
+    BDSS = "BDSS"
+
+
+class TreeDatasetGenerator(DatasetGenerator):
+    data_type: Literal[DataType.TREES] = DataType.TREES
+    min_tips: cfg.IntegerConfig = 1
+    max_tips: cfg.IntegerConfig | None = None
+    max_time: cfg.ScalarConfig = np.inf
+    init_state: str | None = None
+    sampling_probability_at_present: cfg.ScalarConfig = 0.0
+    max_tries: int | None = None
+
+    def simulate_one(self, rng: Generator, data: dict[str, Any]) -> Tree | None:
+        events = self._get_events(rng, data)
+        init_state = (
+            self.init_state
+            if self.init_state is None
+            else self.init_state.format(**data)
+        )
+        max_tips = (
+            self.max_tips if self.max_tips is None else integer(self.max_tips, data)
+        )
+        return simulate_tree(
+            events=events,
+            min_tips=integer(self.min_tips, data),
+            max_tips=max_tips,
+            max_time=scalar(self.max_time, data),
+            init_state=init_state,
+            sampling_probability_at_present=scalar(
+                self.sampling_probability_at_present, data
+            ),
+            max_tries=self.max_tries,
+            seed=int(rng.integers(2**32)),
+        )
+
+    @abstractmethod
+    def _get_events(self, rng: Generator, data: dict[str, Any]) -> list[Event]: ...
+
+    def _generate_one(
+        self, filename: str, rng: Generator, data: dict[str, Any]
+    ) -> None:
+        tree = self.simulate_one(rng, data)
+        if tree is not None:
+            dump_newick(tree, f"{filename}.nwk")
+
+
+class CanonicalTreeDatasetGenerator(TreeDatasetGenerator):
+    parameterization: Literal[ParameterizationType.CANONICAL] = (
+        ParameterizationType.CANONICAL
+    )
+    sampling_rates: cfg.SkylineVectorConfig
+    birth_rates: cfg.SkylineVectorConfig = 0
+    death_rates: cfg.SkylineVectorConfig = 0
+    removal_probabilities: cfg.SkylineVectorConfig = 0
+    migration_rates: cfg.SkylineMatrixConfig = 0
+    birth_rates_among_states: cfg.SkylineMatrixConfig = 0
+    states: list[str] | None = None
+
+    def _get_events(self, rng: Generator, data: dict[str, Any]) -> list[Event]:
+        return get_canonical_events(
+            states=self.states,
+            sampling_rates=skyline_vector(self.sampling_rates, data),
+            birth_rates=skyline_vector(self.birth_rates, data),
+            death_rates=skyline_vector(self.death_rates, data),
+            removal_probabilities=skyline_vector(self.removal_probabilities, data),
+            migration_rates=skyline_matrix(self.migration_rates, data),
+            birth_rates_among_states=skyline_matrix(
+                self.birth_rates_among_states, data
+            ),
+        )
+
+
+class EpidemiologicalTreeDatasetGenerator(TreeDatasetGenerator):
+    parameterization: Literal[ParameterizationType.EPIDEMIOLOGICAL] = (
+        ParameterizationType.EPIDEMIOLOGICAL
+    )
+    states: list[str] | None = None
+    reproduction_numbers: cfg.SkylineVectorConfig = 0
+    become_uninfectious_rates: cfg.SkylineVectorConfig = 0
+    sampling_proportions: cfg.SkylineVectorConfig = 1
+    removal_probabilities: cfg.SkylineVectorConfig = 1
+    migration_rates: cfg.SkylineMatrixConfig = 0
+    reproduction_numbers_among_states: cfg.SkylineMatrixConfig = 0
+
+    def _get_events(self, rng: Generator, data: dict[str, Any]) -> list[Event]:
+        return get_epidemiological_events(
+            states=self.states,
+            reproduction_numbers=skyline_vector(self.reproduction_numbers, data),
+            become_uninfectious_rates=skyline_vector(
+                self.become_uninfectious_rates, data
+            ),
+            sampling_proportions=skyline_vector(self.sampling_proportions, data),
+            removal_probabilities=skyline_vector(self.removal_probabilities, data),
+            migration_rates=skyline_matrix(self.migration_rates, data),
+            reproduction_numbers_among_states=skyline_matrix(
+                self.reproduction_numbers_among_states, data
+            ),
+        )
+
+
+class FBDTreeDatasetGenerator(TreeDatasetGenerator):
+    parameterization: Literal[ParameterizationType.FBD] = ParameterizationType.FBD
+    states: list[str] | None = None
+    diversification: cfg.SkylineVectorConfig = 0
+    turnover: cfg.SkylineVectorConfig = 0
+    sampling_proportions: cfg.SkylineVectorConfig = 1
+    removal_probabilities: cfg.SkylineVectorConfig = 0
+    migration_rates: cfg.SkylineMatrixConfig = 0
+    diversification_between_types: cfg.SkylineMatrixConfig = 0
+
+    def _get_events(self, rng: Generator, data: dict[str, Any]) -> list[Event]:
+        return get_FBD_events(
+            states=self.states,
+            diversification=skyline_vector(self.diversification, data),
+            turnover=skyline_vector(self.turnover, data),
+            sampling_proportions=skyline_vector(self.sampling_proportions, data),
+            removal_probabilities=skyline_vector(self.removal_probabilities, data),
+            migration_rates=skyline_matrix(self.migration_rates, data),
+            diversification_between_types=skyline_matrix(
+                self.diversification_between_types, data
+            ),
+        )
+
+
+class BDTreeDatasetGenerator(TreeDatasetGenerator):
+    parameterization: Literal[ParameterizationType.BD] = ParameterizationType.BD
+    reproduction_number: cfg.SkylineParameterConfig
+    infectious_period: cfg.SkylineParameterConfig
+    sampling_proportion: cfg.SkylineParameterConfig = 1
+
+    def _get_events(self, rng: Generator, data: dict[str, Any]) -> list[Event]:
+        return get_BD_events(
+            reproduction_number=skyline_parameter(self.reproduction_number, data),
+            infectious_period=skyline_parameter(self.infectious_period, data),
+            sampling_proportion=skyline_parameter(self.sampling_proportion, data),
+        )
+
+
+class BDEITreeDatasetGenerator(TreeDatasetGenerator):
+    parameterization: Literal[ParameterizationType.BDEI] = ParameterizationType.BDEI
+    reproduction_number: cfg.SkylineParameterConfig
+    infectious_period: cfg.SkylineParameterConfig
+    incubation_period: cfg.SkylineParameterConfig
+    sampling_proportion: cfg.SkylineParameterConfig = 1
+
+    def _get_events(self, rng: Generator, data: dict[str, Any]) -> list[Event]:
+        return get_BDEI_events(
+            reproduction_number=skyline_parameter(self.reproduction_number, data),
+            infectious_period=skyline_parameter(self.infectious_period, data),
+            incubation_period=skyline_parameter(self.incubation_period, data),
+            sampling_proportion=skyline_parameter(self.sampling_proportion, data),
+        )
+
+
+class BDSSTreeDatasetGenerator(TreeDatasetGenerator):
+    parameterization: Literal[ParameterizationType.BDSS] = ParameterizationType.BDSS
+    reproduction_number: cfg.SkylineParameterConfig
+    infectious_period: cfg.SkylineParameterConfig
+    superspreading_ratio: cfg.SkylineParameterConfig
+    superspreaders_proportion: cfg.SkylineParameterConfig
+    sampling_proportion: cfg.SkylineParameterConfig = 1
+
+    def _get_events(self, rng: Generator, data: dict[str, Any]) -> list[Event]:
+        return get_BDSS_events(
+            reproduction_number=skyline_parameter(self.reproduction_number, data),
+            infectious_period=skyline_parameter(self.infectious_period, data),
+            superspreading_ratio=skyline_parameter(self.superspreading_ratio, data),
+            superspreaders_proportion=skyline_parameter(
+                self.superspreaders_proportion, data
+            ),
+            sampling_proportion=skyline_parameter(self.sampling_proportion, data),
+        )
+
+
+TreeDatasetGeneratorConfig = Annotated[
+    CanonicalTreeDatasetGenerator
+    | EpidemiologicalTreeDatasetGenerator
+    | FBDTreeDatasetGenerator
+    | BDTreeDatasetGenerator
+    | BDEITreeDatasetGenerator
+    | BDSSTreeDatasetGenerator,
+    Field(discriminator="parameterization"),
+]
phylogenie/generators/typeguards.py
@@ -0,0 +1,32 @@
+from typing import TypeGuard
+
+import phylogenie.generators.configs as cfg
+import phylogenie.typings as pgt
+
+
+def is_list(x: object) -> TypeGuard[list[object]]:
+    return isinstance(x, list)
+
+
+def is_list_of_scalar_configs(x: object) -> TypeGuard[list[cfg.ScalarConfig]]:
+    return is_list(x) and all(isinstance(v, cfg.ScalarConfig) for v in x)
+
+
+def is_list_of_skyline_parameter_configs(
+    x: object,
+) -> TypeGuard[list[cfg.SkylineParameterConfig]]:
+    return is_list(x) and all(isinstance(v, cfg.SkylineParameterConfig) for v in x)
+
+
+def is_skyline_vector_config(
+    x: object,
+) -> TypeGuard[cfg.SkylineVectorConfig]:
+    return isinstance(
+        x, str | pgt.Scalar | cfg.SkylineVectorModel
+    ) or is_list_of_skyline_parameter_configs(x)
+
+
+def is_list_of_skyline_vector_configs(
+    x: object,
+) -> TypeGuard[list[cfg.SkylineVectorConfig]]:
+    return is_list(x) and all(is_skyline_vector_config(v) for v in x)