phylogenie 2.1.4__py3-none-any.whl → 3.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40):
  1. phylogenie/__init__.py +60 -14
  2. phylogenie/draw.py +690 -0
  3. phylogenie/generators/alisim.py +12 -12
  4. phylogenie/generators/configs.py +26 -4
  5. phylogenie/generators/dataset.py +3 -3
  6. phylogenie/generators/factories.py +38 -12
  7. phylogenie/generators/trees.py +48 -47
  8. phylogenie/io/__init__.py +3 -0
  9. phylogenie/io/fasta.py +34 -0
  10. phylogenie/main.py +27 -10
  11. phylogenie/mixins.py +33 -0
  12. phylogenie/skyline/matrix.py +11 -7
  13. phylogenie/skyline/parameter.py +12 -4
  14. phylogenie/skyline/vector.py +12 -6
  15. phylogenie/treesimulator/__init__.py +36 -3
  16. phylogenie/treesimulator/events/__init__.py +5 -5
  17. phylogenie/treesimulator/events/base.py +39 -0
  18. phylogenie/treesimulator/events/contact_tracing.py +38 -23
  19. phylogenie/treesimulator/events/core.py +21 -12
  20. phylogenie/treesimulator/events/mutations.py +46 -46
  21. phylogenie/treesimulator/features.py +49 -0
  22. phylogenie/treesimulator/gillespie.py +59 -55
  23. phylogenie/treesimulator/io/__init__.py +4 -0
  24. phylogenie/treesimulator/io/newick.py +104 -0
  25. phylogenie/treesimulator/io/nexus.py +50 -0
  26. phylogenie/treesimulator/model.py +25 -49
  27. phylogenie/treesimulator/tree.py +196 -0
  28. phylogenie/treesimulator/utils.py +108 -0
  29. phylogenie/typings.py +3 -3
  30. {phylogenie-2.1.4.dist-info → phylogenie-3.1.7.dist-info}/METADATA +13 -15
  31. phylogenie-3.1.7.dist-info/RECORD +41 -0
  32. {phylogenie-2.1.4.dist-info → phylogenie-3.1.7.dist-info}/WHEEL +2 -1
  33. phylogenie-3.1.7.dist-info/entry_points.txt +2 -0
  34. phylogenie-3.1.7.dist-info/top_level.txt +1 -0
  35. phylogenie/io.py +0 -107
  36. phylogenie/tree.py +0 -92
  37. phylogenie/utils.py +0 -17
  38. phylogenie-2.1.4.dist-info/RECORD +0 -32
  39. phylogenie-2.1.4.dist-info/entry_points.txt +0 -3
  40. {phylogenie-2.1.4.dist-info → phylogenie-3.1.7.dist-info/licenses}/LICENSE.txt +0 -0
@@ -6,9 +6,9 @@ from typing import Any, Literal
6
6
  from numpy.random import Generator, default_rng
7
7
 
8
8
  from phylogenie.generators.dataset import DatasetGenerator, DataType
9
- from phylogenie.generators.factories import data
9
+ from phylogenie.generators.factories import data, string
10
10
  from phylogenie.generators.trees import TreeDatasetGeneratorConfig
11
- from phylogenie.io import dump_newick
11
+ from phylogenie.treesimulator import dump_newick, get_node_depths
12
12
 
13
13
  MSAS_DIRNAME = "MSAs"
14
14
  TREES_DIRNAME = "trees"
@@ -19,7 +19,7 @@ class AliSimDatasetGenerator(DatasetGenerator):
19
19
  trees: TreeDatasetGeneratorConfig
20
20
  keep_trees: bool = False
21
21
  iqtree_path: str = "iqtree2"
22
- args: dict[str, str | int | float]
22
+ args: dict[str, Any]
23
23
 
24
24
  def _generate_one_from_tree(
25
25
  self, filename: str, tree_file: str, rng: Generator, data: dict[str, Any]
@@ -35,9 +35,7 @@ class AliSimDatasetGenerator(DatasetGenerator):
35
35
  ]
36
36
 
37
37
  for key, value in self.args.items():
38
- command.extend(
39
- [key, value.format(**data) if isinstance(value, str) else str(value)]
40
- )
38
+ command.extend([key, string(value, data)])
41
39
 
42
40
  command.extend(["-af", "fasta"])
43
41
  subprocess.run(command, check=True, stdout=subprocess.DEVNULL)
@@ -61,24 +59,26 @@ class AliSimDatasetGenerator(DatasetGenerator):
61
59
  tree_filename = f"{filename}.temp-tree"
62
60
  msa_filename = filename
63
61
 
64
- d: dict[str, Any] = {"file_id": Path(msa_filename).stem}
62
+ md: dict[str, Any] = {"file_id": Path(msa_filename).stem}
65
63
  rng = default_rng(seed)
66
64
  while True:
67
- d.update(data(context, rng))
65
+ md.update(data(context, rng))
68
66
  try:
69
- tree = self.trees.simulate_one(d, seed)
67
+ tree, metadata = self.trees.simulate_one(md, seed)
70
68
  break
71
69
  except TimeoutError:
72
70
  print(
73
71
  "Tree simulation timed out, retrying with different parameters..."
74
72
  )
73
+ md.update(metadata)
75
74
 
75
+ times = get_node_depths(tree)
76
76
  for leaf in tree.get_leaves():
77
- leaf.id += f"|{leaf.get_time()}"
77
+ leaf.name += f"|{times[leaf]}"
78
78
  dump_newick(tree, f"{tree_filename}.nwk")
79
79
 
80
- self._generate_one_from_tree(msa_filename, f"{tree_filename}.nwk", rng, d)
80
+ self._generate_one_from_tree(msa_filename, f"{tree_filename}.nwk", rng, md)
81
81
  if not self.keep_trees:
82
82
  os.remove(f"{tree_filename}.nwk")
83
83
 
84
- return d
84
+ return md
@@ -1,6 +1,28 @@
1
+ from typing import Any
2
+
3
+ from numpy.random import Generator
4
+ from pydantic import BaseModel, ConfigDict
5
+
1
6
  import phylogenie.typings as pgt
2
- from phylogenie.treesimulator import MutationTargetType
3
- from phylogenie.utils import Distribution, StrictBaseModel
7
+ from phylogenie.treesimulator import EventType
8
+
9
+
10
class StrictBaseModel(BaseModel):
    """Pydantic base model configured with ``extra="forbid"``."""

    # Unknown fields in a config are rejected instead of being ignored.
    model_config = ConfigDict(extra="forbid")
12
+
13
+
14
class Distribution(BaseModel):
    """Configuration of a random draw from a NumPy ``Generator``.

    ``type`` is the name of the ``Generator`` attribute to call. Every other
    field given in the config is kept as a pydantic "extra" and forwarded to
    that call as a keyword argument.
    """

    type: str
    model_config = ConfigDict(extra="allow")

    @property
    def args(self) -> dict[str, Any]:
        """Return the keyword arguments forwarded to the sampling method.

        A shallow copy is returned on purpose: callers that resolve string
        expressions by assigning into the returned dict must not overwrite
        the configured values, otherwise an expression would be evaluated
        only once and every later draw would reuse the stale result.
        """
        # ``model_extra`` is only None when extras are not allowed, which
        # ``model_config`` rules out; fall back to an empty dict rather than
        # relying on an ``assert`` that disappears under ``python -O``.
        return dict(self.model_extra or {})

    def __call__(self, rng: Generator) -> Any:
        """Draw from the distribution by calling ``rng.<type>(**args)``.

        Raises:
            AttributeError: if ``rng`` has no attribute named ``self.type``.
        """
        return getattr(rng, self.type)(**self.args)
25
+
4
26
 
5
27
  Integer = str | int
6
28
  Scalar = str | pgt.Scalar
@@ -30,9 +52,9 @@ SkylineMatrix = str | pgt.Scalar | pgt.Many[SkylineVector] | SkylineMatrixModel
30
52
 
31
53
 
32
54
  class Event(StrictBaseModel):
33
- states: str | list[str] | None = None
55
+ state: str | None = None
34
56
  rate: SkylineParameter
35
57
 
36
58
 
37
59
  class Mutation(Event):
38
- rate_scalers: dict[MutationTargetType, Distribution]
60
+ rate_scalers: dict[EventType, Distribution]
@@ -8,7 +8,7 @@ import pandas as pd
8
8
  from numpy.random import Generator, default_rng
9
9
  from tqdm import tqdm
10
10
 
11
- from phylogenie.utils import Distribution, StrictBaseModel
11
+ from phylogenie.generators.configs import Distribution, StrictBaseModel
12
12
 
13
13
 
14
14
  class DataType(str, Enum):
@@ -31,7 +31,7 @@ class DatasetGenerator(ABC, StrictBaseModel):
31
31
  def generate_one(
32
32
  self,
33
33
  filename: str,
34
- context: dict[str, Any] | None = None,
34
+ context: dict[str, Distribution] | None = None,
35
35
  seed: int | None = None,
36
36
  ) -> dict[str, Any]: ...
37
37
 
@@ -56,7 +56,7 @@ class DatasetGenerator(ABC, StrictBaseModel):
56
56
  for i in range(n_samples)
57
57
  )
58
58
  df = pd.DataFrame(
59
- [r for r in tqdm(jobs, total=n_samples, desc=f"Generating {data_dir}...")]
59
+ [j for j in tqdm(jobs, f"Generating {data_dir}...", n_samples)]
60
60
  )
61
61
  df.to_csv(os.path.join(output_dir, METADATA_FILENAME), index=False)
62
62
 
@@ -1,4 +1,5 @@
1
- from typing import Any
1
+ import re
2
+ from typing import Any, Callable
2
3
 
3
4
  import numpy as np
4
5
  from numpy.random import Generator
@@ -15,18 +16,14 @@ from phylogenie.skyline import (
15
16
  SkylineVector,
16
17
  SkylineVectorCoercible,
17
18
  )
18
- from phylogenie.utils import Distribution
19
+ from phylogenie.treesimulator import EventType, Mutation
19
20
 
20
21
 
21
22
  def _eval_expression(expression: str, data: dict[str, Any]) -> Any:
22
23
  return np.array(
23
24
  eval(
24
25
  expression,
25
- {
26
- "__builtins__": __builtins__,
27
- "np": np,
28
- **{k: np.array(v) for k, v in data.items()},
29
- },
26
+ {"np": np, **{k: np.array(v) for k, v in data.items()}},
30
27
  )
31
28
  ).tolist()
32
29
 
@@ -53,6 +50,14 @@ def scalar(x: cfg.Scalar, data: dict[str, Any]) -> pgt.Scalar:
53
50
  return x
54
51
 
55
52
 
53
def string(s: Any, data: dict[str, Any]) -> str:
    """Render *s* as a string, expanding ``{expression}`` placeholders.

    Non-string values are simply converted with ``str``. For strings, every
    brace-delimited span that contains no nested braces is replaced by the
    stringified result of ``_eval_expression`` applied to its content.
    """
    if isinstance(s, str):

        def _expand(match: re.Match[str]) -> str:
            return str(_eval_expression(match.group(1), data))

        # Match content inside curly braces
        return re.sub(r"\{([^{}]+)\}", _expand, s)
    return str(s)
59
+
60
+
56
61
  def many_scalars(x: cfg.ManyScalars, data: dict[str, Any]) -> pgt.ManyScalars:
57
62
  if isinstance(x, str):
58
63
  e = _eval_expression(x, data)
@@ -204,19 +209,40 @@ def skyline_matrix(
204
209
  return SkylineMatrix(value=value, change_times=change_times)
205
210
 
206
211
 
207
- def distribution(x: Distribution, data: dict[str, Any]) -> Distribution:
212
+ def distribution(x: cfg.Distribution, data: dict[str, Any]) -> cfg.Distribution:
208
213
  args = x.args
209
214
  for arg_name, arg_value in args.items():
210
215
  if isinstance(arg_value, str):
211
216
  args[arg_name] = _eval_expression(arg_value, data)
212
- return Distribution(type=x.type, **args)
217
+ return cfg.Distribution(type=x.type, **args)
218
+
219
+
220
def mutations(
    x: list[cfg.Mutation],
    data: dict[str, Any],
    states: set[str],
    rates_to_log: list[EventType] | None,
    rng: Generator,
) -> list[Mutation]:
    """Build the simulator ``Mutation`` events described by the configs in *x*.

    Args:
        x: Mutation configurations to convert.
        data: Values used to resolve expressions in rates and distributions.
        states: States a mutation is instantiated for when its config does
            not name a single ``state``.
        rates_to_log: Passed through unchanged to every ``Mutation``.
        rng: Random generator the rate scalers draw from.

    Returns:
        One ``Mutation`` per config that names a state, and one per entry of
        *states* for each config whose ``state`` is ``None``.
    """

    def _make_sampler(dist_cfg: cfg.Distribution) -> Callable[[], float]:
        # Bind the distribution through a function argument. A lambda written
        # directly in the comprehension below would close over the loop
        # variable, so every scaler would sample from the *last* distribution.
        return lambda: distribution(dist_cfg, data)(rng)

    result: list[Mutation] = []
    for m in x:
        rate = skyline_parameter(m.rate, data)
        rate_scalers: dict[EventType, Callable[[], float]] = {
            k: _make_sampler(v) for k, v in m.rate_scalers.items()
        }
        if m.state is None:
            # NOTE(review): iteration order of a set of strings is not stable
            # across interpreter runs; confirm event order does not affect
            # seeded reproducibility.
            result.extend(
                Mutation(s, rate, rate_scalers, rates_to_log) for s in states
            )
        else:
            result.append(Mutation(m.state, rate, rate_scalers, rates_to_log))
    return result
213
240
 
214
241
 
215
- def data(context: dict[str, Distribution] | None, rng: Generator) -> dict[str, Any]:
242
+ def data(context: dict[str, cfg.Distribution] | None, rng: Generator) -> dict[str, Any]:
216
243
  if context is None:
217
244
  return {}
218
245
  data: dict[str, Any] = {}
219
246
  for k, v in context.items():
220
- dist = distribution(v, data)
221
- data[k] = np.array(getattr(rng, dist.type)(**dist.args)).tolist()
247
+ data[k] = np.array(distribution(v, data)(rng)).tolist()
222
248
  return data
@@ -1,28 +1,30 @@
1
1
  from abc import abstractmethod
2
2
  from enum import Enum
3
3
  from pathlib import Path
4
- from typing import Annotated, Any, Literal
4
+ from typing import Annotated, Any, Callable, Literal
5
5
 
6
6
  import numpy as np
7
7
  from numpy.random import default_rng
8
8
  from pydantic import Field
9
9
 
10
10
  import phylogenie.generators.configs as cfg
11
+ from phylogenie.generators.configs import Distribution
11
12
  from phylogenie.generators.dataset import DatasetGenerator, DataType
12
13
  from phylogenie.generators.factories import (
13
14
  data,
14
- distribution,
15
15
  integer,
16
+ mutations,
16
17
  scalar,
17
18
  skyline_matrix,
18
19
  skyline_parameter,
19
20
  skyline_vector,
20
21
  )
21
- from phylogenie.io import dump_newick
22
- from phylogenie.tree import Tree
23
22
  from phylogenie.treesimulator import (
24
23
  Event,
25
- Mutation,
24
+ EventType,
25
+ Feature,
26
+ Tree,
27
+ dump_newick,
26
28
  get_BD_events,
27
29
  get_BDEI_events,
28
30
  get_BDSS_events,
@@ -30,6 +32,7 @@ from phylogenie.treesimulator import (
30
32
  get_contact_tracing_events,
31
33
  get_epidemiological_events,
32
34
  get_FBD_events,
35
+ set_features,
33
36
  simulate_tree,
34
37
  )
35
38
 
@@ -45,26 +48,42 @@ class ParameterizationType(str, Enum):
45
48
 
46
49
  class TreeDatasetGenerator(DatasetGenerator):
47
50
  data_type: Literal[DataType.TREES] = DataType.TREES
48
- min_tips: cfg.Integer = 1
49
- max_tips: cfg.Integer = 2**32
51
+ mutations: list[cfg.Mutation] = Field(default_factory=lambda: [])
52
+ rates_to_log: list[EventType] | None = None
53
+ n_tips: cfg.Integer | None = None
50
54
  max_time: cfg.Scalar = np.inf
51
55
  init_state: str | None = None
52
56
  sampling_probability_at_present: cfg.Scalar = 0.0
53
57
  timeout: float = np.inf
58
+ node_features: list[Feature] | None = None
59
+ acceptance_criterion: str | None = None
54
60
 
55
61
  @abstractmethod
56
62
  def _get_events(self, data: dict[str, Any]) -> list[Event]: ...
57
63
 
58
- def simulate_one(self, data: dict[str, Any], seed: int | None = None) -> Tree:
64
+ def simulate_one(
65
+ self, data: dict[str, Any], seed: int | None = None
66
+ ) -> tuple[Tree, dict[str, Any]]:
59
67
  init_state = (
60
68
  self.init_state
61
69
  if self.init_state is None
62
70
  else self.init_state.format(**data)
63
71
  )
72
+ events = self._get_events(data)
73
+ states = {e.state for e in events}
74
+ events += mutations(
75
+ self.mutations, data, states, self.rates_to_log, default_rng(seed)
76
+ )
77
+ acceptance_criterion: None | Callable[[Tree], bool] = (
78
+ None
79
+ if self.acceptance_criterion is None
80
+ else lambda tree: eval(
81
+ self.acceptance_criterion, {}, {"tree": tree} # pyright: ignore
82
+ )
83
+ )
64
84
  return simulate_tree(
65
- events=self._get_events(data),
66
- min_tips=integer(self.min_tips, data),
67
- max_tips=integer(self.max_tips, data),
85
+ events=events,
86
+ n_tips=None if self.n_tips is None else integer(self.n_tips, data),
68
87
  max_time=scalar(self.max_time, data),
69
88
  init_state=init_state,
70
89
  sampling_probability_at_present=scalar(
@@ -72,12 +91,13 @@ class TreeDatasetGenerator(DatasetGenerator):
72
91
  ),
73
92
  seed=seed,
74
93
  timeout=self.timeout,
94
+ acceptance_criterion=acceptance_criterion,
75
95
  )
76
96
 
77
97
  def generate_one(
78
98
  self,
79
99
  filename: str,
80
- context: dict[str, Any] | None = None,
100
+ context: dict[str, Distribution] | None = None,
81
101
  seed: int | None = None,
82
102
  ) -> dict[str, Any]:
83
103
  d = {"file_id": Path(filename).stem}
@@ -85,12 +105,14 @@ class TreeDatasetGenerator(DatasetGenerator):
85
105
  while True:
86
106
  try:
87
107
  d.update(data(context, rng))
88
- tree = self.simulate_one(d, seed)
108
+ tree, metadata = self.simulate_one(d, seed)
109
+ if self.node_features is not None:
110
+ set_features(tree, self.node_features)
89
111
  dump_newick(tree, f"{filename}.nwk")
90
112
  break
91
113
  except TimeoutError:
92
114
  print("Simulation timed out, retrying with different parameters...")
93
- return d
115
+ return d | metadata
94
116
 
95
117
 
96
118
  class CanonicalTreeDatasetGenerator(TreeDatasetGenerator):
@@ -98,8 +120,8 @@ class CanonicalTreeDatasetGenerator(TreeDatasetGenerator):
98
120
  ParameterizationType.CANONICAL
99
121
  )
100
122
  states: list[str]
101
- sampling_rates: cfg.SkylineVector
102
- remove_after_sampling: bool
123
+ sampling_rates: cfg.SkylineVector = 0
124
+ remove_after_sampling: bool = False
103
125
  birth_rates: cfg.SkylineVector = 0
104
126
  death_rates: cfg.SkylineVector = 0
105
127
  migration_rates: cfg.SkylineMatrix = None
@@ -122,7 +144,7 @@ class CanonicalTreeDatasetGenerator(TreeDatasetGenerator):
122
144
  class FBDTreeDatasetGenerator(TreeDatasetGenerator):
123
145
  parameterization: Literal[ParameterizationType.FBD] = ParameterizationType.FBD
124
146
  states: list[str]
125
- sampling_proportions: cfg.SkylineVector
147
+ sampling_proportions: cfg.SkylineVector = 0
126
148
  diversification: cfg.SkylineVector = 0
127
149
  turnover: cfg.SkylineVector = 0
128
150
  migration_rates: cfg.SkylineMatrix = None
@@ -141,12 +163,11 @@ class FBDTreeDatasetGenerator(TreeDatasetGenerator):
141
163
  )
142
164
 
143
165
 
144
- class TreeDatasetGeneratorForEpidemiology(TreeDatasetGenerator):
166
+ class ContactTracingTreeDatasetGenerator(TreeDatasetGenerator):
145
167
  max_notified_contacts: cfg.Integer = 1
146
168
  notification_probability: cfg.SkylineParameter = 0.0
147
- sampling_rate_after_notification: cfg.SkylineParameter = np.inf
169
+ sampling_rate_after_notification: cfg.SkylineParameter = 2**32
148
170
  samplable_states_after_notification: list[str] | None = None
149
- mutations: tuple[cfg.Mutation, ...] = Field(default_factory=tuple)
150
171
 
151
172
  @abstractmethod
152
173
  def _get_base_events(self, data: dict[str, Any]) -> list[Event]: ...
@@ -165,30 +186,10 @@ class TreeDatasetGeneratorForEpidemiology(TreeDatasetGenerator):
165
186
  ),
166
187
  samplable_states_after_notification=self.samplable_states_after_notification,
167
188
  )
168
- all_states = list({e.state for e in events})
169
- for mutation in self.mutations:
170
- states = mutation.states
171
- if isinstance(states, str):
172
- states = [states]
173
- elif states is None:
174
- states = all_states
175
- for state in states:
176
- if state not in all_states:
177
- raise ValueError(
178
- f"Mutation state '{state}' is not found in states {all_states}."
179
- )
180
- rate_scalers = {
181
- t: distribution(r, data) for t, r in mutation.rate_scalers.items()
182
- }
183
- events.append(
184
- Mutation(
185
- state, skyline_parameter(mutation.rate, data), rate_scalers
186
- )
187
- )
188
189
  return events
189
190
 
190
191
 
191
- class EpidemiologicalTreeDatasetGenerator(TreeDatasetGeneratorForEpidemiology):
192
+ class EpidemiologicalTreeDatasetGenerator(ContactTracingTreeDatasetGenerator):
192
193
  parameterization: Literal[ParameterizationType.EPIDEMIOLOGICAL] = (
193
194
  ParameterizationType.EPIDEMIOLOGICAL
194
195
  )
@@ -214,11 +215,11 @@ class EpidemiologicalTreeDatasetGenerator(TreeDatasetGeneratorForEpidemiology):
214
215
  )
215
216
 
216
217
 
217
- class BDTreeDatasetGenerator(TreeDatasetGeneratorForEpidemiology):
218
+ class BDTreeDatasetGenerator(ContactTracingTreeDatasetGenerator):
218
219
  parameterization: Literal[ParameterizationType.BD] = ParameterizationType.BD
219
220
  reproduction_number: cfg.SkylineParameter
220
221
  infectious_period: cfg.SkylineParameter
221
- sampling_proportion: cfg.SkylineParameter = 1
222
+ sampling_proportion: cfg.SkylineParameter
222
223
 
223
224
  def _get_base_events(self, data: dict[str, Any]) -> list[Event]:
224
225
  return get_BD_events(
@@ -228,12 +229,12 @@ class BDTreeDatasetGenerator(TreeDatasetGeneratorForEpidemiology):
228
229
  )
229
230
 
230
231
 
231
- class BDEITreeDatasetGenerator(TreeDatasetGeneratorForEpidemiology):
232
+ class BDEITreeDatasetGenerator(ContactTracingTreeDatasetGenerator):
232
233
  parameterization: Literal[ParameterizationType.BDEI] = ParameterizationType.BDEI
233
234
  reproduction_number: cfg.SkylineParameter
234
235
  infectious_period: cfg.SkylineParameter
235
236
  incubation_period: cfg.SkylineParameter
236
- sampling_proportion: cfg.SkylineParameter = 1
237
+ sampling_proportion: cfg.SkylineParameter
237
238
 
238
239
  def _get_base_events(self, data: dict[str, Any]) -> list[Event]:
239
240
  return get_BDEI_events(
@@ -244,13 +245,13 @@ class BDEITreeDatasetGenerator(TreeDatasetGeneratorForEpidemiology):
244
245
  )
245
246
 
246
247
 
247
- class BDSSTreeDatasetGenerator(TreeDatasetGeneratorForEpidemiology):
248
+ class BDSSTreeDatasetGenerator(ContactTracingTreeDatasetGenerator):
248
249
  parameterization: Literal[ParameterizationType.BDSS] = ParameterizationType.BDSS
249
250
  reproduction_number: cfg.SkylineParameter
250
251
  infectious_period: cfg.SkylineParameter
251
252
  superspreading_ratio: cfg.SkylineParameter
252
253
  superspreaders_proportion: cfg.SkylineParameter
253
- sampling_proportion: cfg.SkylineParameter = 1
254
+ sampling_proportion: cfg.SkylineParameter
254
255
 
255
256
  def _get_base_events(self, data: dict[str, Any]) -> list[Event]:
256
257
  return get_BDSS_events(
@@ -0,0 +1,3 @@
1
+ from phylogenie.io.fasta import dump_fasta, load_fasta
2
+
3
+ __all__ = ["load_fasta", "dump_fasta"]
phylogenie/io/fasta.py ADDED
@@ -0,0 +1,34 @@
1
+ from pathlib import Path
2
+ from typing import Callable
3
+
4
+ from phylogenie.msa import MSA, Sequence
5
+
6
+
7
def load_fasta(
    fasta_file: str | Path, extract_time_from_id: Callable[[str], float] | None = None
) -> MSA:
    """Load a FASTA file whose records are a header line plus one sequence line.

    Args:
        fasta_file: Path of the file to read.
        extract_time_from_id: Optional callable deriving a sequence's time
            from its identifier. When omitted and the identifier contains
            ``|``, the text after the last ``|`` is used if it parses as a
            float; otherwise the time is left as ``None``.

    Returns:
        An ``MSA`` built from the sequences in file order.

    Raises:
        ValueError: if a record does not start with ``>`` or if a header is
            not followed by a sequence line.
    """
    sequences: list[Sequence] = []
    with open(fasta_file, "r") as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines between records (typically a trailing
                # newline at the end of the file).
                continue
            if not line.startswith(">"):
                raise ValueError(f"Invalid FASTA format: expected '>', got '{line[0]}'")
            seq_id = line[1:].strip()
            time = None
            if extract_time_from_id is not None:
                time = extract_time_from_id(seq_id)
            elif "|" in seq_id:
                try:
                    time = float(seq_id.split("|")[-1])
                except ValueError:
                    # The suffix is not numeric: deliberately leave time unset.
                    pass
            # A bare ``next(f)`` would leak StopIteration on a truncated file.
            chars_line = next(f, None)
            if chars_line is None:
                raise ValueError(
                    f"Invalid FASTA format: no sequence found for '{seq_id}'"
                )
            sequences.append(Sequence(seq_id, chars_line.strip(), time))
    return MSA(sequences)
27
+
28
+
29
def dump_fasta(msa: MSA | list[Sequence], fasta_file: str | Path) -> None:
    """Write sequences to *fasta_file*, one ``>id`` line and one sequence line each.

    *msa* may be an ``MSA`` (its ``sequences`` are written) or a plain list
    of ``Sequence`` objects.
    """
    with open(fasta_file, "w") as f:
        records = msa if not isinstance(msa, MSA) else msa.sequences
        for record in records:
            f.write(f">{record.id}\n{record.chars}\n")
phylogenie/main.py CHANGED
@@ -2,27 +2,44 @@ import os
2
2
  from argparse import ArgumentParser
3
3
  from glob import glob
4
4
 
5
- from pydantic import TypeAdapter
5
+ from pydantic import TypeAdapter, ValidationError
6
6
  from yaml import safe_load
7
7
 
8
8
  from phylogenie.generators import DatasetGeneratorConfig
9
9
  from phylogenie.generators.dataset import DatasetGenerator
10
10
 
11
11
 
12
- def run(config_path: str) -> None:
12
def _format_validation_error(e: ValidationError) -> str:
    """Render each error of *e* as a ``- location: message (type)`` line.

    The location is the error's ``loc`` entries joined with dots; the lines
    are joined with newlines.
    """
    lines: list[str] = []
    for err in e.errors():
        location = ".".join(map(str, err["loc"]))
        lines.append(f"- {location}: {err['msg']} ({err['type']})")
    return "\n".join(lines)
18
+
19
+
20
def _generate_from_config_file(config_file: str) -> None:
    """Load one YAML config, validate it and run the resulting generator.

    On a parse or validation failure a message is printed and the process
    exits with code -1; otherwise ``generate()`` is called on the validated
    generator.

    Args:
        config_file: Path of the YAML configuration file.

    Raises:
        SystemExit: if the file cannot be parsed or fails validation.
    """
    adapter: TypeAdapter[DatasetGenerator] = TypeAdapter(DatasetGeneratorConfig)
    # YAML streams are Unicode; do not depend on the platform's locale encoding.
    with open(config_file, "r", encoding="utf-8") as f:
        try:
            config = safe_load(f)
        except Exception as e:  # CLI boundary: report any parse failure
            print(f"❌ Failed to parse {config_file}: {e}")
            # Raise SystemExit directly: the ``exit`` builtin is injected by
            # the ``site`` module for interactive use and may be missing.
            raise SystemExit(-1)
    try:
        generator = adapter.validate_python(config)
    except ValidationError as e:
        print("❌ Invalid configuration:")
        print(_format_validation_error(e))
        raise SystemExit(-1)
    generator.generate()
14
35
 
36
+
37
+ def run(config_path: str) -> None:
15
38
  if os.path.isdir(config_path):
16
39
  for config_file in glob(os.path.join(config_path, "**/*.yaml"), recursive=True):
17
- with open(config_file, "r") as f:
18
- config = safe_load(f)
19
- generator = adapter.validate_python(config)
20
- generator.generate()
40
+ _generate_from_config_file(config_file)
21
41
  else:
22
- with open(config_path, "r") as f:
23
- config = safe_load(f)
24
- generator = adapter.validate_python(config)
25
- generator.generate()
42
+ _generate_from_config_file(config_path)
26
43
 
27
44
 
28
45
  def main() -> None:
phylogenie/mixins.py ADDED
@@ -0,0 +1,33 @@
1
+ from collections.abc import Mapping
2
+ from types import MappingProxyType
3
+ from typing import Any
4
+
5
+
6
class MetadataMixin:
    """Mixin giving an object a private string-keyed metadata store."""

    def __init__(self) -> None:
        self._metadata: dict[str, Any] = {}

    @property
    def metadata(self) -> Mapping[str, Any]:
        """Read-only live view of the stored metadata."""
        return MappingProxyType(self._metadata)

    def set(self, key: str, value: Any) -> None:
        """Store *value* under *key*, replacing any previous value."""
        self._metadata[key] = value

    def update(self, metadata: Mapping[str, Any]) -> None:
        """Merge every entry of *metadata* into the store."""
        self._metadata.update(metadata)

    def get(self, key: str, default: Any = None) -> Any:
        """Return the value stored under *key*, or *default* if absent."""
        try:
            return self._metadata[key]
        except KeyError:
            return default

    def delete(self, key: str) -> None:
        """Remove *key* from the store; a missing key is ignored."""
        if key in self._metadata:
            del self._metadata[key]

    def clear(self) -> None:
        """Remove every entry from the store."""
        self._metadata.clear()

    def __getitem__(self, key: str) -> Any:
        """Return the value stored under *key*; raise ``KeyError`` if absent."""
        return self._metadata[key]

    def __setitem__(self, key: str, value: Any) -> None:
        """Store *value* under *key* (subscript form of ``set``)."""
        self._metadata[key] = value
@@ -33,7 +33,7 @@ class SkylineMatrix:
33
33
  ):
34
34
  if params is not None and value is None and change_times is None:
35
35
  if is_many_skyline_vectors_like(params):
36
- self.params = [
36
+ self._params = [
37
37
  p if isinstance(p, SkylineVector) else SkylineVector(p)
38
38
  for p in params
39
39
  ]
@@ -41,7 +41,7 @@ class SkylineMatrix:
41
41
  raise TypeError(
42
42
  f"It is impossible to create a SkylineMatrix from `params` {params} of type {type(params)}. Please provide a sequence composed of SkylineVectorLike objects (a SkylineVectorLike object can either be a SkylineVector or a sequence of scalars and/or SkylineParameters)."
43
43
  )
44
- lengths = {len(p) for p in self.params}
44
+ lengths = {len(p) for p in self._params}
45
45
  if len(lengths) > 1:
46
46
  raise ValueError(
47
47
  f"All `params` must have the same length to create a SkylineMatrix (got params={params} with lengths {lengths})."
@@ -57,7 +57,7 @@ class SkylineMatrix:
57
57
  raise TypeError(
58
58
  f"It is impossible to create a SkylineMatrix from `value` {value} of type {type(value)}. Please provide a nested (3D) sequence of scalar values."
59
59
  )
60
- self.params = [
60
+ self._params = [
61
61
  SkylineVector(
62
62
  value=[matrix[i] for matrix in value], change_times=change_times
63
63
  )
@@ -68,6 +68,10 @@ class SkylineMatrix:
68
68
  "Either `params` or both `value` and `change_times` must be provided to create a SkylineMatrix."
69
69
  )
70
70
 
71
+ @property
72
+ def params(self) -> tuple[SkylineVector, ...]:
73
+ return tuple(self._params)
74
+
71
75
  @property
72
76
  def n_rows(self) -> int:
73
77
  return len(self.params)
@@ -82,14 +86,14 @@ class SkylineMatrix:
82
86
 
83
87
  @property
84
88
  def change_times(self) -> pgt.Vector1D:
85
- return sorted(set([t for row in self.params for t in row.change_times]))
89
+ return tuple(sorted(set([t for row in self.params for t in row.change_times])))
86
90
 
87
91
  @property
88
92
  def value(self) -> pgt.Vector3D:
89
- return [self.get_value_at_time(t) for t in (0, *self.change_times)]
93
+ return tuple(self.get_value_at_time(t) for t in (0, *self.change_times))
90
94
 
91
95
  def get_value_at_time(self, time: pgt.Scalar) -> pgt.Vector2D:
92
- return [param.get_value_at_time(time) for param in self.params]
96
+ return tuple(param.get_value_at_time(time) for param in self.params)
93
97
 
94
98
  def _operate(
95
99
  self,
@@ -185,7 +189,7 @@ class SkylineMatrix:
185
189
  raise TypeError(
186
190
  f"It is impossible to set item of SkylineMatrix to value {value} of type {type(value)}. Please provide a SkylineVectorLike object (i.e., a SkylineVector or a sequence of scalars and/or SkylineParameters)."
187
191
  )
188
- self.params[item] = skyline_vector(value, self.n_cols)
192
+ self._params[item] = skyline_vector(value, self.n_cols)
189
193
 
190
194
 
191
195
  def skyline_matrix(
@@ -52,12 +52,20 @@ class SkylineParameter:
52
52
  f"`change_times` must be non-negative (got change_times={change_times})."
53
53
  )
54
54
 
55
- self.value = [value[0]]
56
- self.change_times: list[pgt.Scalar] = []
55
+ self._value = [value[0]]
56
+ self._change_times: list[pgt.Scalar] = []
57
57
  for i in range(1, len(value)):
58
58
  if value[i] != value[i - 1]:
59
- self.value.append(value[i])
60
- self.change_times.append(change_times[i - 1])
59
+ self._value.append(value[i])
60
+ self._change_times.append(change_times[i - 1])
61
+
62
+ @property
63
+ def value(self) -> pgt.Vector1D:
64
+ return tuple(self._value)
65
+
66
+ @property
67
+ def change_times(self) -> pgt.Vector1D:
68
+ return tuple(self._change_times)
61
69
 
62
70
  def get_value_at_time(self, t: pgt.Scalar) -> pgt.Scalar:
63
71
  if t < 0: