phylogenie 2.0.14__tar.gz → 2.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. {phylogenie-2.0.14 → phylogenie-2.1.1}/PKG-INFO +1 -2
  2. {phylogenie-2.0.14 → phylogenie-2.1.1}/README.md +0 -1
  3. {phylogenie-2.0.14 → phylogenie-2.1.1}/phylogenie/__init__.py +22 -7
  4. {phylogenie-2.0.14 → phylogenie-2.1.1}/phylogenie/generators/__init__.py +0 -8
  5. {phylogenie-2.0.14 → phylogenie-2.1.1}/phylogenie/generators/configs.py +11 -12
  6. {phylogenie-2.0.14 → phylogenie-2.1.1}/phylogenie/generators/dataset.py +10 -7
  7. {phylogenie-2.0.14 → phylogenie-2.1.1}/phylogenie/generators/factories.py +9 -0
  8. {phylogenie-2.0.14 → phylogenie-2.1.1}/phylogenie/generators/trees.py +88 -61
  9. {phylogenie-2.0.14 → phylogenie-2.1.1}/phylogenie/io.py +27 -12
  10. {phylogenie-2.0.14 → phylogenie-2.1.1}/phylogenie/tree.py +38 -12
  11. {phylogenie-2.0.14 → phylogenie-2.1.1}/phylogenie/treesimulator/__init__.py +18 -1
  12. phylogenie-2.1.1/phylogenie/treesimulator/events/__init__.py +39 -0
  13. phylogenie-2.1.1/phylogenie/treesimulator/events/contact_tracing.py +125 -0
  14. phylogenie-2.0.14/phylogenie/treesimulator/events.py → phylogenie-2.1.1/phylogenie/treesimulator/events/core.py +73 -125
  15. phylogenie-2.1.1/phylogenie/treesimulator/events/mutations.py +105 -0
  16. phylogenie-2.1.1/phylogenie/treesimulator/gillespie.py +123 -0
  17. {phylogenie-2.0.14 → phylogenie-2.1.1}/phylogenie/treesimulator/model.py +57 -56
  18. {phylogenie-2.0.14 → phylogenie-2.1.1}/phylogenie/typings.py +0 -1
  19. phylogenie-2.1.1/phylogenie/utils.py +17 -0
  20. {phylogenie-2.0.14 → phylogenie-2.1.1}/pyproject.toml +1 -1
  21. phylogenie-2.0.14/phylogenie/treesimulator/gillespie.py +0 -86
  22. {phylogenie-2.0.14 → phylogenie-2.1.1}/LICENSE.txt +0 -0
  23. {phylogenie-2.0.14 → phylogenie-2.1.1}/phylogenie/generators/alisim.py +0 -0
  24. {phylogenie-2.0.14 → phylogenie-2.1.1}/phylogenie/generators/typeguards.py +0 -0
  25. {phylogenie-2.0.14 → phylogenie-2.1.1}/phylogenie/main.py +0 -0
  26. {phylogenie-2.0.14 → phylogenie-2.1.1}/phylogenie/msa.py +0 -0
  27. {phylogenie-2.0.14 → phylogenie-2.1.1}/phylogenie/py.typed +0 -0
  28. {phylogenie-2.0.14 → phylogenie-2.1.1}/phylogenie/skyline/__init__.py +0 -0
  29. {phylogenie-2.0.14 → phylogenie-2.1.1}/phylogenie/skyline/matrix.py +0 -0
  30. {phylogenie-2.0.14 → phylogenie-2.1.1}/phylogenie/skyline/parameter.py +0 -0
  31. {phylogenie-2.0.14 → phylogenie-2.1.1}/phylogenie/skyline/vector.py +0 -0
  32. {phylogenie-2.0.14 → phylogenie-2.1.1}/phylogenie/typeguards.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: phylogenie
3
- Version: 2.0.14
3
+ Version: 2.1.1
4
4
  Summary: Generate phylogenetic datasets with minimal setup effort
5
5
  Author: Gabriele Marino
6
6
  Author-email: gabmarino.8601@gmail.com
@@ -24,7 +24,6 @@ Description-Content-Type: text/markdown
24
24
 
25
25
  [![AliSim](https://img.shields.io/badge/Powered%20by-AliSim-orange?style=flat-square)](https://iqtree.github.io/doc/AliSim)
26
26
  [![PyPI version](https://img.shields.io/pypi/v/phylogenie)](https://pypi.org/project/phylogenie/)
27
- [![PyPI downloads](https://shields.io/pypi/dm/phylogenie)](https://pypi.org/project/phylogenie/)
28
27
 
29
28
  Phylogenie is a [Python](https://www.python.org/) package designed to easily simulate phylogenetic datasets—such as trees and multiple sequence alignments (MSAs)—with minimal setup effort. Simply specify the distributions from which your parameters should be sampled, and Phylogenie will handle the rest!
30
29
 
@@ -6,7 +6,6 @@
6
6
 
7
7
  [![AliSim](https://img.shields.io/badge/Powered%20by-AliSim-orange?style=flat-square)](https://iqtree.github.io/doc/AliSim)
8
8
  [![PyPI version](https://img.shields.io/pypi/v/phylogenie)](https://pypi.org/project/phylogenie/)
9
- [![PyPI downloads](https://shields.io/pypi/dm/phylogenie)](https://pypi.org/project/phylogenie/)
10
9
 
11
10
  Phylogenie is a [Python](https://www.python.org/) package designed to easily simulate phylogenetic datasets—such as trees and multiple sequence alignments (MSAs)—with minimal setup effort. Simply specify the distributions from which your parameters should be sampled, and Phylogenie will handle the rest!
12
11
 
@@ -8,12 +8,9 @@ from phylogenie.generators import (
8
8
  DatasetGeneratorConfig,
9
9
  EpidemiologicalTreeDatasetGenerator,
10
10
  FBDTreeDatasetGenerator,
11
- SkylineMatrixModel,
12
- SkylineParameterModel,
13
- SkylineVectorModel,
14
11
  TreeDatasetGeneratorConfig,
15
12
  )
16
- from phylogenie.io import load_fasta, load_newick
13
+ from phylogenie.io import dump_newick, load_fasta, load_newick
17
14
  from phylogenie.msa import MSA
18
15
  from phylogenie.skyline import (
19
16
  SkylineMatrix,
@@ -29,11 +26,21 @@ from phylogenie.skyline import (
29
26
  )
30
27
  from phylogenie.tree import Tree
31
28
  from phylogenie.treesimulator import (
29
+ Birth,
30
+ BirthWithContactTracing,
31
+ Death,
32
32
  Event,
33
+ Migration,
34
+ Mutation,
35
+ MutationTargetType,
36
+ Sampling,
37
+ SamplingWithContactTracing,
38
+ generate_trees,
33
39
  get_BD_events,
34
40
  get_BDEI_events,
35
41
  get_BDSS_events,
36
42
  get_canonical_events,
43
+ get_contact_tracing_events,
37
44
  get_epidemiological_events,
38
45
  get_FBD_events,
39
46
  simulate_tree,
@@ -51,27 +58,35 @@ __all__ = [
51
58
  "FBDTreeDatasetGenerator",
52
59
  "SkylineMatrix",
53
60
  "SkylineMatrixCoercible",
54
- "SkylineMatrixModel",
55
61
  "skyline_matrix",
56
62
  "SkylineParameter",
57
63
  "SkylineParameterLike",
58
- "SkylineParameterModel",
59
64
  "skyline_parameter",
60
65
  "SkylineVector",
61
66
  "SkylineVectorCoercible",
62
67
  "SkylineVectorLike",
63
- "SkylineVectorModel",
64
68
  "skyline_vector",
65
69
  "Tree",
66
70
  "TreeDatasetGeneratorConfig",
71
+ "Birth",
72
+ "BirthWithContactTracing",
73
+ "Death",
67
74
  "Event",
75
+ "Migration",
76
+ "Mutation",
77
+ "MutationTargetType",
78
+ "Sampling",
79
+ "SamplingWithContactTracing",
68
80
  "get_BD_events",
69
81
  "get_BDEI_events",
70
82
  "get_BDSS_events",
71
83
  "get_canonical_events",
84
+ "get_contact_tracing_events",
72
85
  "get_epidemiological_events",
73
86
  "get_FBD_events",
87
+ "generate_trees",
74
88
  "simulate_tree",
89
+ "dump_newick",
75
90
  "load_fasta",
76
91
  "load_newick",
77
92
  "MSA",
@@ -3,11 +3,6 @@ from typing import Annotated
3
3
  from pydantic import Field
4
4
 
5
5
  from phylogenie.generators.alisim import AliSimDatasetGenerator
6
- from phylogenie.generators.configs import (
7
- SkylineMatrixModel,
8
- SkylineParameterModel,
9
- SkylineVectorModel,
10
- )
11
6
  from phylogenie.generators.dataset import DatasetGenerator
12
7
  from phylogenie.generators.trees import (
13
8
  BDEITreeDatasetGenerator,
@@ -34,7 +29,4 @@ __all__ = [
34
29
  "BDTreeDatasetGenerator",
35
30
  "BDEITreeDatasetGenerator",
36
31
  "BDSSTreeDatasetGenerator",
37
- "SkylineMatrixModel",
38
- "SkylineParameterModel",
39
- "SkylineVectorModel",
40
32
  ]
@@ -1,12 +1,6 @@
1
- from pydantic import BaseModel, ConfigDict
2
-
3
1
  import phylogenie.typings as pgt
4
-
5
-
6
- class Distribution(BaseModel):
7
- type: str
8
- model_config = ConfigDict(extra="allow")
9
-
2
+ from phylogenie.treesimulator import MutationTargetType
3
+ from phylogenie.utils import Distribution, StrictBaseModel
10
4
 
11
5
  Integer = str | int
12
6
  Scalar = str | pgt.Scalar
@@ -15,10 +9,6 @@ OneOrManyScalars = Scalar | pgt.Many[Scalar]
15
9
  OneOrMany2DScalars = Scalar | pgt.Many2D[Scalar]
16
10
 
17
11
 
18
- class StrictBaseModel(BaseModel):
19
- model_config = ConfigDict(extra="forbid")
20
-
21
-
22
12
  class SkylineParameterModel(StrictBaseModel):
23
13
  value: ManyScalars
24
14
  change_times: ManyScalars
@@ -37,3 +27,12 @@ class SkylineMatrixModel(StrictBaseModel):
37
27
  SkylineParameter = Scalar | SkylineParameterModel
38
28
  SkylineVector = str | pgt.Scalar | pgt.Many[SkylineParameter] | SkylineVectorModel
39
29
  SkylineMatrix = str | pgt.Scalar | pgt.Many[SkylineVector] | SkylineMatrixModel | None
30
+
31
+
32
+ class Event(StrictBaseModel):
33
+ states: str | list[str] | None = None
34
+ rate: SkylineParameter
35
+
36
+
37
+ class Mutation(Event):
38
+ rate_scalers: dict[MutationTargetType, Distribution]
@@ -10,7 +10,8 @@ import pandas as pd
10
10
  from numpy.random import Generator, default_rng
11
11
  from tqdm import tqdm
12
12
 
13
- import phylogenie.generators.configs as cfg
13
+ from phylogenie.generators.factories import distribution
14
+ from phylogenie.utils import Distribution, StrictBaseModel
14
15
 
15
16
 
16
17
  class DataType(str, Enum):
@@ -22,12 +23,12 @@ DATA_DIRNAME = "data"
22
23
  METADATA_FILENAME = "metadata.csv"
23
24
 
24
25
 
25
- class DatasetGenerator(ABC, cfg.StrictBaseModel):
26
+ class DatasetGenerator(ABC, StrictBaseModel):
26
27
  output_dir: str = "phylogenie-outputs"
27
28
  n_samples: int | dict[str, int] = 1
28
29
  n_jobs: int = -1
29
30
  seed: int | None = None
30
- context: dict[str, cfg.Distribution] | None = None
31
+ context: dict[str, Distribution] | None = None
31
32
 
32
33
  @abstractmethod
33
34
  def _generate_one(
@@ -55,19 +56,21 @@ class DatasetGenerator(ABC, cfg.StrictBaseModel):
55
56
  data: list[dict[str, Any]] = [{} for _ in range(n_samples)]
56
57
  if self.context is not None:
57
58
  for d, (k, v) in product(data, self.context.items()):
58
- args = v.model_extra if v.model_extra is not None else {}
59
- d[k] = np.array(getattr(rng, v.type)(**args)).tolist()
59
+ dist = distribution(v, d)
60
+ d[k] = np.array(getattr(rng, dist.type)(**dist.args)).tolist()
60
61
  df = pd.DataFrame([{"file_id": str(i), **d} for i, d in enumerate(data)])
61
62
  df.to_csv(os.path.join(output_dir, METADATA_FILENAME), index=False)
62
63
 
63
- joblib.Parallel(n_jobs=self.n_jobs)(
64
+ jobs = joblib.Parallel(n_jobs=self.n_jobs, return_as="generator_unordered")(
64
65
  joblib.delayed(self.generate_one)(
65
66
  filename=os.path.join(data_dir, str(i)),
66
67
  data=data[i],
67
68
  seed=int(rng.integers(2**32)),
68
69
  )
69
- for i in tqdm(range(n_samples), desc=f"Generating {data_dir}...")
70
+ for i in range(n_samples)
70
71
  )
72
+ for _ in tqdm(jobs, total=n_samples, desc=f"Generating {data_dir}..."):
73
+ pass
71
74
 
72
75
  def generate(self) -> None:
73
76
  rng = default_rng(self.seed)
@@ -14,6 +14,7 @@ from phylogenie.skyline import (
14
14
  SkylineVector,
15
15
  SkylineVectorCoercible,
16
16
  )
17
+ from phylogenie.utils import Distribution
17
18
 
18
19
 
19
20
  def _eval_expression(expression: str, data: dict[str, Any]) -> Any:
@@ -29,6 +30,14 @@ def _eval_expression(expression: str, data: dict[str, Any]) -> Any:
29
30
  ).tolist()
30
31
 
31
32
 
33
+ def distribution(x: Distribution, data: dict[str, Any]) -> Distribution:
34
+ args = x.args
35
+ for arg_name, arg_value in args.items():
36
+ if isinstance(arg_value, str):
37
+ args[arg_name] = _eval_expression(arg_value, data)
38
+ return Distribution(type=x.type, **args)
39
+
40
+
32
41
  def integer(x: cfg.Integer, data: dict[str, Any]) -> int:
33
42
  if isinstance(x, str):
34
43
  e = _eval_expression(x, data)
@@ -9,6 +9,7 @@ from pydantic import Field
9
9
  import phylogenie.generators.configs as cfg
10
10
  from phylogenie.generators.dataset import DatasetGenerator, DataType
11
11
  from phylogenie.generators.factories import (
12
+ distribution,
12
13
  integer,
13
14
  scalar,
14
15
  skyline_matrix,
@@ -19,10 +20,12 @@ from phylogenie.io import dump_newick
19
20
  from phylogenie.tree import Tree
20
21
  from phylogenie.treesimulator import (
21
22
  Event,
23
+ Mutation,
22
24
  get_BD_events,
23
25
  get_BDEI_events,
24
26
  get_BDSS_events,
25
27
  get_canonical_events,
28
+ get_contact_tracing_events,
26
29
  get_epidemiological_events,
27
30
  get_FBD_events,
28
31
  simulate_tree,
@@ -41,16 +44,13 @@ class ParameterizationType(str, Enum):
41
44
  class TreeDatasetGenerator(DatasetGenerator):
42
45
  data_type: Literal[DataType.TREES] = DataType.TREES
43
46
  min_tips: cfg.Integer = 1
44
- max_tips: cfg.Integer | None = None
47
+ max_tips: cfg.Integer = 2**32
45
48
  max_time: cfg.Scalar = np.inf
46
49
  init_state: str | None = None
47
50
  sampling_probability_at_present: cfg.Scalar = 0.0
48
- max_tries: int | None = None
49
- notification_probability: cfg.Scalar = 0.0
50
- max_notified_contacts: cfg.Integer = 1
51
- samplable_states_after_notification: list[str] | None = None
52
- sampling_rate_after_notification: cfg.SkylineParameter = np.inf
53
- contacts_removal_probability: cfg.SkylineParameter = 1
51
+
52
+ @abstractmethod
53
+ def _get_events(self, data: dict[str, Any]) -> list[Event]: ...
54
54
 
55
55
  def simulate_one(self, rng: Generator, data: dict[str, Any]) -> Tree | None:
56
56
  events = self._get_events(data)
@@ -59,34 +59,18 @@ class TreeDatasetGenerator(DatasetGenerator):
59
59
  if self.init_state is None
60
60
  else self.init_state.format(**data)
61
61
  )
62
- max_tips = (
63
- self.max_tips if self.max_tips is None else integer(self.max_tips, data)
64
- )
65
62
  return simulate_tree(
66
63
  events=events,
67
64
  min_tips=integer(self.min_tips, data),
68
- max_tips=max_tips,
65
+ max_tips=integer(self.max_tips, data),
69
66
  max_time=scalar(self.max_time, data),
70
67
  init_state=init_state,
71
68
  sampling_probability_at_present=scalar(
72
69
  self.sampling_probability_at_present, data
73
70
  ),
74
- notification_probability=scalar(self.notification_probability, data),
75
- max_notified_contacts=integer(self.max_notified_contacts, data),
76
- samplable_states_after_notification=self.samplable_states_after_notification,
77
- sampling_rate_after_notification=skyline_parameter(
78
- self.sampling_rate_after_notification, data
79
- ),
80
- contacts_removal_probability=skyline_parameter(
81
- self.contacts_removal_probability, data
82
- ),
83
- max_tries=self.max_tries,
84
71
  seed=int(rng.integers(2**32)),
85
72
  )
86
73
 
87
- @abstractmethod
88
- def _get_events(self, data: dict[str, Any]) -> list[Event]: ...
89
-
90
74
  def _generate_one(
91
75
  self, filename: str, rng: Generator, data: dict[str, Any]
92
76
  ) -> None:
@@ -101,9 +85,9 @@ class CanonicalTreeDatasetGenerator(TreeDatasetGenerator):
101
85
  )
102
86
  states: list[str]
103
87
  sampling_rates: cfg.SkylineVector
88
+ remove_after_sampling: bool
104
89
  birth_rates: cfg.SkylineVector = 0
105
90
  death_rates: cfg.SkylineVector = 0
106
- removal_probabilities: cfg.SkylineVector = 0
107
91
  migration_rates: cfg.SkylineMatrix = None
108
92
  birth_rates_among_states: cfg.SkylineMatrix = None
109
93
 
@@ -111,9 +95,9 @@ class CanonicalTreeDatasetGenerator(TreeDatasetGenerator):
111
95
  return get_canonical_events(
112
96
  states=self.states,
113
97
  sampling_rates=skyline_vector(self.sampling_rates, data),
98
+ remove_after_sampling=self.remove_after_sampling,
114
99
  birth_rates=skyline_vector(self.birth_rates, data),
115
100
  death_rates=skyline_vector(self.death_rates, data),
116
- removal_probabilities=skyline_vector(self.removal_probabilities, data),
117
101
  migration_rates=skyline_matrix(self.migration_rates, data),
118
102
  birth_rates_among_states=skyline_matrix(
119
103
  self.birth_rates_among_states, data
@@ -121,19 +105,87 @@ class CanonicalTreeDatasetGenerator(TreeDatasetGenerator):
121
105
  )
122
106
 
123
107
 
124
- class EpidemiologicalTreeDatasetGenerator(TreeDatasetGenerator):
108
+ class FBDTreeDatasetGenerator(TreeDatasetGenerator):
109
+ parameterization: Literal[ParameterizationType.FBD] = ParameterizationType.FBD
110
+ states: list[str]
111
+ sampling_proportions: cfg.SkylineVector
112
+ diversification: cfg.SkylineVector = 0
113
+ turnover: cfg.SkylineVector = 0
114
+ migration_rates: cfg.SkylineMatrix = None
115
+ diversification_between_states: cfg.SkylineMatrix = None
116
+
117
+ def _get_events(self, data: dict[str, Any]) -> list[Event]:
118
+ return get_FBD_events(
119
+ states=self.states,
120
+ diversification=skyline_vector(self.diversification, data),
121
+ turnover=skyline_vector(self.turnover, data),
122
+ sampling_proportions=skyline_vector(self.sampling_proportions, data),
123
+ migration_rates=skyline_matrix(self.migration_rates, data),
124
+ diversification_between_states=skyline_matrix(
125
+ self.diversification_between_states, data
126
+ ),
127
+ )
128
+
129
+
130
+ class TreeDatasetGeneratorForEpidemiology(TreeDatasetGenerator):
131
+ max_notified_contacts: cfg.Integer = 1
132
+ notification_probability: cfg.SkylineParameter = 0.0
133
+ sampling_rate_after_notification: cfg.SkylineParameter = np.inf
134
+ samplable_states_after_notification: list[str] | None = None
135
+ mutations: tuple[cfg.Mutation, ...] = Field(default_factory=tuple)
136
+
137
+ @abstractmethod
138
+ def _get_base_events(self, data: dict[str, Any]) -> list[Event]: ...
139
+
140
+ def _get_events(self, data: dict[str, Any]) -> list[Event]:
141
+ events = self._get_base_events(data)
142
+ if self.notification_probability:
143
+ events = get_contact_tracing_events(
144
+ events=events,
145
+ max_notified_contacts=integer(self.max_notified_contacts, data),
146
+ notification_probability=skyline_parameter(
147
+ self.notification_probability, data
148
+ ),
149
+ sampling_rate_after_notification=skyline_parameter(
150
+ self.sampling_rate_after_notification, data
151
+ ),
152
+ samplable_states_after_notification=self.samplable_states_after_notification,
153
+ )
154
+ all_states = list({e.state for e in events})
155
+ for mutation in self.mutations:
156
+ states = mutation.states
157
+ if isinstance(states, str):
158
+ states = [states]
159
+ elif states is None:
160
+ states = all_states
161
+ for state in states:
162
+ if state not in all_states:
163
+ raise ValueError(
164
+ f"Mutation state '{state}' is not found in states {all_states}."
165
+ )
166
+ rate_scalers = {
167
+ t: distribution(r, data) for t, r in mutation.rate_scalers.items()
168
+ }
169
+ events.append(
170
+ Mutation(
171
+ state, skyline_parameter(mutation.rate, data), rate_scalers
172
+ )
173
+ )
174
+ return events
175
+
176
+
177
+ class EpidemiologicalTreeDatasetGenerator(TreeDatasetGeneratorForEpidemiology):
125
178
  parameterization: Literal[ParameterizationType.EPIDEMIOLOGICAL] = (
126
179
  ParameterizationType.EPIDEMIOLOGICAL
127
180
  )
128
181
  states: list[str]
182
+ sampling_proportions: cfg.SkylineVector
129
183
  reproduction_numbers: cfg.SkylineVector = 0
130
184
  become_uninfectious_rates: cfg.SkylineVector = 0
131
- sampling_proportions: cfg.SkylineVector = 1
132
- removal_probabilities: cfg.SkylineVector = 1
133
185
  migration_rates: cfg.SkylineMatrix = None
134
186
  reproduction_numbers_among_states: cfg.SkylineMatrix = None
135
187
 
136
- def _get_events(self, data: dict[str, Any]) -> list[Event]:
188
+ def _get_base_events(self, data: dict[str, Any]) -> list[Event]:
137
189
  return get_epidemiological_events(
138
190
  states=self.states,
139
191
  reproduction_numbers=skyline_vector(self.reproduction_numbers, data),
@@ -141,7 +193,6 @@ class EpidemiologicalTreeDatasetGenerator(TreeDatasetGenerator):
141
193
  self.become_uninfectious_rates, data
142
194
  ),
143
195
  sampling_proportions=skyline_vector(self.sampling_proportions, data),
144
- removal_probabilities=skyline_vector(self.removal_probabilities, data),
145
196
  migration_rates=skyline_matrix(self.migration_rates, data),
146
197
  reproduction_numbers_among_states=skyline_matrix(
147
198
  self.reproduction_numbers_among_states, data
@@ -149,37 +200,13 @@ class EpidemiologicalTreeDatasetGenerator(TreeDatasetGenerator):
149
200
  )
150
201
 
151
202
 
152
- class FBDTreeDatasetGenerator(TreeDatasetGenerator):
153
- parameterization: Literal[ParameterizationType.FBD] = ParameterizationType.FBD
154
- states: list[str]
155
- diversification: cfg.SkylineVector = 0
156
- turnover: cfg.SkylineVector = 0
157
- sampling_proportions: cfg.SkylineVector = 1
158
- removal_probabilities: cfg.SkylineVector = 0
159
- migration_rates: cfg.SkylineMatrix = None
160
- diversification_between_types: cfg.SkylineMatrix = None
161
-
162
- def _get_events(self, data: dict[str, Any]) -> list[Event]:
163
- return get_FBD_events(
164
- states=self.states,
165
- diversification=skyline_vector(self.diversification, data),
166
- turnover=skyline_vector(self.turnover, data),
167
- sampling_proportions=skyline_vector(self.sampling_proportions, data),
168
- removal_probabilities=skyline_vector(self.removal_probabilities, data),
169
- migration_rates=skyline_matrix(self.migration_rates, data),
170
- diversification_between_types=skyline_matrix(
171
- self.diversification_between_types, data
172
- ),
173
- )
174
-
175
-
176
- class BDTreeDatasetGenerator(TreeDatasetGenerator):
203
+ class BDTreeDatasetGenerator(TreeDatasetGeneratorForEpidemiology):
177
204
  parameterization: Literal[ParameterizationType.BD] = ParameterizationType.BD
178
205
  reproduction_number: cfg.SkylineParameter
179
206
  infectious_period: cfg.SkylineParameter
180
207
  sampling_proportion: cfg.SkylineParameter = 1
181
208
 
182
- def _get_events(self, data: dict[str, Any]) -> list[Event]:
209
+ def _get_base_events(self, data: dict[str, Any]) -> list[Event]:
183
210
  return get_BD_events(
184
211
  reproduction_number=skyline_parameter(self.reproduction_number, data),
185
212
  infectious_period=skyline_parameter(self.infectious_period, data),
@@ -187,14 +214,14 @@ class BDTreeDatasetGenerator(TreeDatasetGenerator):
187
214
  )
188
215
 
189
216
 
190
- class BDEITreeDatasetGenerator(TreeDatasetGenerator):
217
+ class BDEITreeDatasetGenerator(TreeDatasetGeneratorForEpidemiology):
191
218
  parameterization: Literal[ParameterizationType.BDEI] = ParameterizationType.BDEI
192
219
  reproduction_number: cfg.SkylineParameter
193
220
  infectious_period: cfg.SkylineParameter
194
221
  incubation_period: cfg.SkylineParameter
195
222
  sampling_proportion: cfg.SkylineParameter = 1
196
223
 
197
- def _get_events(self, data: dict[str, Any]) -> list[Event]:
224
+ def _get_base_events(self, data: dict[str, Any]) -> list[Event]:
198
225
  return get_BDEI_events(
199
226
  reproduction_number=skyline_parameter(self.reproduction_number, data),
200
227
  infectious_period=skyline_parameter(self.infectious_period, data),
@@ -203,7 +230,7 @@ class BDEITreeDatasetGenerator(TreeDatasetGenerator):
203
230
  )
204
231
 
205
232
 
206
- class BDSSTreeDatasetGenerator(TreeDatasetGenerator):
233
+ class BDSSTreeDatasetGenerator(TreeDatasetGeneratorForEpidemiology):
207
234
  parameterization: Literal[ParameterizationType.BDSS] = ParameterizationType.BDSS
208
235
  reproduction_number: cfg.SkylineParameter
209
236
  infectious_period: cfg.SkylineParameter
@@ -211,7 +238,7 @@ class BDSSTreeDatasetGenerator(TreeDatasetGenerator):
211
238
  superspreaders_proportion: cfg.SkylineParameter
212
239
  sampling_proportion: cfg.SkylineParameter = 1
213
240
 
214
- def _get_events(self, data: dict[str, Any]) -> list[Event]:
241
+ def _get_base_events(self, data: dict[str, Any]) -> list[Event]:
215
242
  return get_BDSS_events(
216
243
  reproduction_number=skyline_parameter(self.reproduction_number, data),
217
244
  infectious_period=skyline_parameter(self.infectious_period, data),
@@ -24,21 +24,28 @@ def _parse_newick(newick: str) -> Tree:
24
24
  stack.append(current_nodes)
25
25
  current_nodes = []
26
26
  else:
27
- id = _parse_chars([":", ",", ")", ";"])
27
+ id = _parse_chars([":", ",", ")", ";", "["])
28
28
  branch_length = None
29
29
  if newick[i] == ":":
30
30
  i += 1
31
- branch_length = _parse_chars([",", ")", ";"])
31
+ branch_length = float(_parse_chars([",", ")", ";", "["]))
32
32
 
33
- current_node = Tree(
34
- id=id,
35
- branch_length=(None if branch_length is None else float(branch_length)),
36
- )
33
+ current_node = Tree(id, branch_length)
37
34
  for node in current_children:
38
35
  current_node.add_child(node)
39
36
  current_children = []
40
37
  current_nodes.append(current_node)
41
38
 
39
+ if newick[i] == "[":
40
+ i += 1
41
+ features = _parse_chars(["]"]).split(":")
42
+ i += 1
43
+ if features[0] != "&&NHX":
44
+ raise ValueError(f"Expected '&&NHX' for node features.")
45
+ for feature in features[1:]:
46
+ key, value = feature.split("=", 1)
47
+ current_node.set(key, eval(value))
48
+
42
49
  if newick[i] == ")":
43
50
  current_children = current_nodes
44
51
  current_nodes = stack.pop()
@@ -47,7 +54,7 @@ def _parse_newick(newick: str) -> Tree:
47
54
 
48
55
  i += 1
49
56
 
50
- raise ValueError("Newick string does not end with a semicolon.")
57
+ raise ValueError("Newick string is invalid.")
51
58
 
52
59
 
53
60
  def load_newick(filepath: str) -> Tree | list[Tree]:
@@ -63,12 +70,19 @@ def _to_newick(tree: Tree) -> str:
63
70
  newick = f"({children_newick}){newick}"
64
71
  if tree.branch_length is not None:
65
72
  newick += f":{tree.branch_length}"
73
+ if tree.features:
74
+ reprs = {k: repr(v).replace("'", '"') for k, v in tree.features.items()}
75
+ features = [f"{k}={repr}" for k, repr in reprs.items()]
76
+ newick += f"[&&NHX:{':'.join(features)}]"
66
77
  return newick
67
78
 
68
79
 
69
- def dump_newick(tree: Tree, filepath: str) -> None:
80
+ def dump_newick(trees: Tree | list[Tree], filepath: str) -> None:
81
+ if isinstance(trees, Tree):
82
+ trees = [trees]
70
83
  with open(filepath, "w") as file:
71
- file.write(_to_newick(tree) + ";")
84
+ for t in trees:
85
+ file.write(_to_newick(t) + ";\n")
72
86
 
73
87
 
74
88
  def load_fasta(
@@ -80,13 +94,14 @@ def load_fasta(
80
94
  if not line.startswith(">"):
81
95
  raise ValueError(f"Invalid FASTA format: expected '>', got '{line[0]}'")
82
96
  id = line[1:].strip()
97
+ time = None
83
98
  if extract_time_from_id is not None:
84
99
  time = extract_time_from_id(id)
85
- else:
100
+ elif "|" in id:
86
101
  try:
87
102
  time = float(id.split("|")[-1])
88
- except:
89
- time = None
103
+ except ValueError:
104
+ pass
90
105
  chars = next(f).strip()
91
106
  sequences.append(Sequence(id, chars, time))
92
107
  return MSA(sequences)
@@ -1,18 +1,41 @@
1
1
  from collections.abc import Iterator
2
+ from typing import Any
2
3
 
3
4
 
4
5
  class Tree:
5
6
  def __init__(self, id: str = "", branch_length: float | None = None):
6
7
  self.id = id
7
8
  self.branch_length = branch_length
8
- self.parent: Tree | None = None
9
- self.children: list[Tree] = []
9
+ self._parent: Tree | None = None
10
+ self._children: list[Tree] = []
11
+ self._features: dict[str, Any] = {}
12
+
13
+ @property
14
+ def children(self) -> tuple["Tree", ...]:
15
+ return tuple(self._children)
16
+
17
+ @property
18
+ def parent(self) -> "Tree | None":
19
+ return self._parent
20
+
21
+ @property
22
+ def features(self) -> dict[str, Any]:
23
+ return self._features
10
24
 
11
25
  def add_child(self, child: "Tree") -> "Tree":
12
- child.parent = self
13
- self.children.append(child)
26
+ child._parent = self
27
+ self._children.append(child)
14
28
  return self
15
29
 
30
+ def remove_child(self, child: "Tree") -> None:
31
+ self._children.remove(child)
32
+ child._parent = None
33
+
34
+ def set_parent(self, node: "Tree | None"):
35
+ self._parent = node
36
+ if node is not None:
37
+ node._children.append(self)
38
+
16
39
  def preorder_traversal(self) -> Iterator["Tree"]:
17
40
  yield self
18
41
  for child in self.children:
@@ -29,6 +52,9 @@ class Tree:
29
52
  return node
30
53
  raise ValueError(f"Node with id {id} not found.")
31
54
 
55
+ def is_leaf(self) -> bool:
56
+ return not self.children
57
+
32
58
  def get_leaves(self) -> list["Tree"]:
33
59
  return [node for node in self if not node.children]
34
60
 
@@ -38,17 +64,17 @@ class Tree:
38
64
  raise ValueError(f"Branch length of node {self.id} is not set.")
39
65
  return self.branch_length + parent_time
40
66
 
41
- def is_leaf(self) -> bool:
42
- return not self.children
67
+ def set(self, key: str, value: Any) -> None:
68
+ self._features[key] = value
43
69
 
44
- def copy(self) -> "Tree":
45
- new_tree = Tree(self.id, self.branch_length)
46
- for child in self.children:
47
- new_tree.add_child(child.copy())
48
- return new_tree
70
+ def get(self, key: str) -> Any:
71
+ return self._features.get(key)
72
+
73
+ def delete(self, key: str) -> None:
74
+ del self._features[key]
49
75
 
50
76
  def __iter__(self) -> Iterator["Tree"]:
51
77
  return self.preorder_traversal()
52
78
 
53
79
  def __repr__(self) -> str:
54
- return f"TreeNode(id='{self.id}', branch_length={self.branch_length})"
80
+ return f"TreeNode(id='{self.id}', branch_length={self.branch_length}, features={self.features})"