phylogenie 2.1.30__tar.gz → 3.1.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. {phylogenie-2.1.30 → phylogenie-3.1.5}/PKG-INFO +11 -16
  2. phylogenie-3.1.5/pyproject.toml +32 -0
  3. phylogenie-3.1.5/setup.cfg +4 -0
  4. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/__init__.py +35 -28
  5. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/draw.py +47 -48
  6. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/generators/alisim.py +7 -8
  7. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/generators/configs.py +23 -1
  8. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/generators/dataset.py +1 -1
  9. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/generators/factories.py +10 -13
  10. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/generators/trees.py +19 -10
  11. phylogenie-3.1.5/src/phylogenie/io/__init__.py +3 -0
  12. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/main.py +27 -10
  13. phylogenie-3.1.5/src/phylogenie/mixins.py +41 -0
  14. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/treesimulator/__init__.py +29 -3
  15. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/treesimulator/events/base.py +0 -3
  16. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/treesimulator/events/contact_tracing.py +8 -9
  17. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/treesimulator/events/mutations.py +7 -8
  18. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/treesimulator/features.py +3 -3
  19. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/treesimulator/gillespie.py +21 -37
  20. phylogenie-3.1.5/src/phylogenie/treesimulator/io/__init__.py +4 -0
  21. {phylogenie-2.1.30/phylogenie → phylogenie-3.1.5/src/phylogenie/treesimulator}/io/newick.py +3 -3
  22. {phylogenie-2.1.30/phylogenie → phylogenie-3.1.5/src/phylogenie/treesimulator}/io/nexus.py +12 -7
  23. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/treesimulator/model.py +7 -10
  24. {phylogenie-2.1.30/phylogenie → phylogenie-3.1.5/src/phylogenie/treesimulator}/tree.py +110 -84
  25. phylogenie-3.1.5/src/phylogenie/treesimulator/utils.py +108 -0
  26. phylogenie-3.1.5/src/phylogenie.egg-info/PKG-INFO +101 -0
  27. phylogenie-3.1.5/src/phylogenie.egg-info/SOURCES.txt +44 -0
  28. phylogenie-3.1.5/src/phylogenie.egg-info/dependency_links.txt +1 -0
  29. phylogenie-3.1.5/src/phylogenie.egg-info/entry_points.txt +2 -0
  30. phylogenie-3.1.5/src/phylogenie.egg-info/requires.txt +6 -0
  31. phylogenie-3.1.5/src/phylogenie.egg-info/top_level.txt +1 -0
  32. phylogenie-2.1.30/phylogenie/io/__init__.py +0 -5
  33. phylogenie-2.1.30/phylogenie/models.py +0 -17
  34. phylogenie-2.1.30/phylogenie/utils.py +0 -176
  35. phylogenie-2.1.30/pyproject.toml +0 -32
  36. {phylogenie-2.1.30 → phylogenie-3.1.5}/LICENSE.txt +0 -0
  37. {phylogenie-2.1.30 → phylogenie-3.1.5}/README.md +0 -0
  38. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/generators/__init__.py +0 -0
  39. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/generators/typeguards.py +0 -0
  40. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/io/fasta.py +0 -0
  41. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/msa.py +0 -0
  42. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/py.typed +0 -0
  43. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/skyline/__init__.py +0 -0
  44. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/skyline/matrix.py +0 -0
  45. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/skyline/parameter.py +0 -0
  46. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/skyline/vector.py +0 -0
  47. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/treesimulator/events/__init__.py +0 -0
  48. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/treesimulator/events/core.py +0 -0
  49. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/typeguards.py +0 -0
  50. {phylogenie-2.1.30 → phylogenie-3.1.5/src}/phylogenie/typings.py +0 -0
@@ -1,21 +1,17 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.4
2
2
  Name: phylogenie
3
- Version: 2.1.30
3
+ Version: 3.1.5
4
4
  Summary: Generate phylogenetic datasets with minimal setup effort
5
- Author: Gabriele Marino
6
- Author-email: gabmarino.8601@gmail.com
7
- Requires-Python: >=3.10,<4.0
8
- Classifier: Programming Language :: Python :: 3
9
- Classifier: Programming Language :: Python :: 3.10
10
- Classifier: Programming Language :: Python :: 3.11
11
- Classifier: Programming Language :: Python :: 3.12
12
- Requires-Dist: joblib (>=1.4.2,<2.0.0)
13
- Requires-Dist: matplotlib (>=3.10.6,<4.0.0)
14
- Requires-Dist: pandas (>=2.2.2,<3.0.0)
15
- Requires-Dist: pydantic (>=2.11.5,<3.0.0)
16
- Requires-Dist: pyyaml (>=6.0.2,<7.0.0)
17
- Requires-Dist: tqdm (>=4.66.4,<5.0.0)
5
+ Requires-Python: >=3.10
18
6
  Description-Content-Type: text/markdown
7
+ License-File: LICENSE.txt
8
+ Requires-Dist: joblib>=1.5.2
9
+ Requires-Dist: matplotlib>=3.10.7
10
+ Requires-Dist: pandas>=2.3.3
11
+ Requires-Dist: pydantic>=2.12.3
12
+ Requires-Dist: pyyaml>=6.0.3
13
+ Requires-Dist: tqdm>=4.67.1
14
+ Dynamic: license-file
19
15
 
20
16
  <p align="center">
21
17
  <img src="https://raw.githubusercontent.com/gabriele-marino/phylogenie/main/logo.png" style="width:100%; height:auto;"/>
@@ -103,4 +99,3 @@ This project is licensed under [MIT License](https://raw.githubusercontent.com/g
103
99
  For questions, bug reports, or feature requests, please consider opening an [issue on GitHub](https://github.com/gabriele-marino/phylogenie/issues) or [contacting me directly](mailto:gabmarino.8601@gmail.com).
104
100
 
105
101
  If you need help with the configuration files, feel free to reach out — I am always available and happy to assist!
106
-
@@ -0,0 +1,32 @@
1
+ [project]
2
+ name = "phylogenie"
3
+ version = "3.1.5"
4
+ description = "Generate phylogenetic datasets with minimal setup effort"
5
+ readme = "README.md"
6
+ requires-python = ">=3.10"
7
+ dependencies = [
8
+ "joblib>=1.5.2",
9
+ "matplotlib>=3.10.7",
10
+ "pandas>=2.3.3",
11
+ "pydantic>=2.12.3",
12
+ "pyyaml>=6.0.3",
13
+ "tqdm>=4.67.1",
14
+ ]
15
+
16
+ [dependency-groups]
17
+ dev = [
18
+ "joblib-stubs>=1.5.2.0.20250831",
19
+ "pandas-stubs>=2.3.2.250926",
20
+ "pyright>=1.1.407",
21
+ "pytest>=8.4.2",
22
+ ]
23
+
24
+ [tool.pyright]
25
+ typeCheckingMode = "strict"
26
+
27
+ [project.scripts]
28
+ phylogenie = "phylogenie.main:main"
29
+
30
+ [build-system]
31
+ requires = ["setuptools>=42"]
32
+ build-backend = "setuptools.build_meta"
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -11,7 +11,7 @@ from phylogenie.generators import (
11
11
  FBDTreeDatasetGenerator,
12
12
  TreeDatasetGeneratorConfig,
13
13
  )
14
- from phylogenie.io import dump_newick, load_fasta, load_newick, load_nexus
14
+ from phylogenie.io import load_fasta
15
15
  from phylogenie.msa import MSA
16
16
  from phylogenie.skyline import (
17
17
  SkylineMatrix,
@@ -25,41 +25,47 @@ from phylogenie.skyline import (
25
25
  skyline_parameter,
26
26
  skyline_vector,
27
27
  )
28
- from phylogenie.tree import Tree
29
28
  from phylogenie.treesimulator import (
30
29
  Birth,
31
30
  BirthWithContactTracing,
32
31
  Death,
33
32
  Event,
34
33
  EventType,
34
+ Feature,
35
35
  Migration,
36
36
  Mutation,
37
37
  Sampling,
38
38
  SamplingWithContactTracing,
39
+ Tree,
40
+ compute_mean_leaf_pairwise_distance,
41
+ compute_sackin_index,
42
+ dump_newick,
39
43
  generate_trees,
40
44
  get_BD_events,
41
45
  get_BDEI_events,
42
46
  get_BDSS_events,
43
47
  get_canonical_events,
44
48
  get_contact_tracing_events,
49
+ get_distance,
45
50
  get_epidemiological_events,
46
51
  get_FBD_events,
47
- simulate_tree,
48
- )
49
- from phylogenie.utils import (
50
- compute_colless_index,
51
- compute_mean_leaf_pairwise_distance,
52
- compute_sackin_index,
53
- get_distance,
54
52
  get_mrca,
53
+ get_mutation_id,
55
54
  get_node_depth_levels,
56
55
  get_node_depths,
57
56
  get_node_height_levels,
58
57
  get_node_heights,
59
58
  get_node_leaf_counts,
59
+ get_node_state,
60
+ load_newick,
61
+ load_nexus,
62
+ set_features,
63
+ simulate_tree,
60
64
  )
61
65
 
62
66
  __all__ = [
67
+ "Coloring",
68
+ "draw_tree",
63
69
  "AliSimDatasetGenerator",
64
70
  "BDEITreeDatasetGenerator",
65
71
  "BDSSTreeDatasetGenerator",
@@ -69,51 +75,52 @@ __all__ = [
69
75
  "DatasetGeneratorConfig",
70
76
  "EpidemiologicalTreeDatasetGenerator",
71
77
  "FBDTreeDatasetGenerator",
78
+ "TreeDatasetGeneratorConfig",
79
+ "load_fasta",
80
+ "MSA",
72
81
  "SkylineMatrix",
73
82
  "SkylineMatrixCoercible",
74
- "skyline_matrix",
75
83
  "SkylineParameter",
76
84
  "SkylineParameterLike",
77
- "skyline_parameter",
78
85
  "SkylineVector",
79
86
  "SkylineVectorCoercible",
80
87
  "SkylineVectorLike",
88
+ "skyline_matrix",
89
+ "skyline_parameter",
81
90
  "skyline_vector",
82
- "Tree",
83
- "TreeDatasetGeneratorConfig",
84
91
  "Birth",
85
92
  "BirthWithContactTracing",
86
93
  "Death",
87
94
  "Event",
88
95
  "EventType",
96
+ "Feature",
89
97
  "Migration",
90
98
  "Mutation",
91
99
  "Sampling",
92
100
  "SamplingWithContactTracing",
101
+ "Tree",
102
+ "compute_mean_leaf_pairwise_distance",
103
+ "compute_sackin_index",
104
+ "dump_newick",
105
+ "generate_trees",
93
106
  "get_BD_events",
94
107
  "get_BDEI_events",
95
108
  "get_BDSS_events",
96
109
  "get_canonical_events",
97
110
  "get_contact_tracing_events",
111
+ "get_distance",
98
112
  "get_epidemiological_events",
99
113
  "get_FBD_events",
100
- "generate_trees",
101
- "simulate_tree",
102
- "dump_newick",
103
- "load_nexus",
104
- "load_fasta",
105
- "load_newick",
106
- "MSA",
107
- "Coloring",
108
- "draw_tree",
109
- "compute_colless_index",
110
- "compute_mean_leaf_pairwise_distance",
111
- "compute_sackin_index",
112
- "get_distance",
113
114
  "get_mrca",
114
- "get_node_depths",
115
+ "get_mutation_id",
115
116
  "get_node_depth_levels",
116
- "get_node_heights",
117
+ "get_node_depths",
117
118
  "get_node_height_levels",
119
+ "get_node_heights",
118
120
  "get_node_leaf_counts",
121
+ "get_node_state",
122
+ "load_newick",
123
+ "load_nexus",
124
+ "set_features",
125
+ "simulate_tree",
119
126
  ]
@@ -1,15 +1,14 @@
1
1
  from enum import Enum
2
- from itertools import islice
3
- from typing import Any
2
+ from typing import Any, Callable
4
3
 
5
4
  import matplotlib.colors as mcolors
6
5
  import matplotlib.patches as mpatches
7
6
  import matplotlib.pyplot as plt
8
7
  from matplotlib.axes import Axes
8
+ from matplotlib.colors import Colormap
9
9
  from mpl_toolkits.axes_grid1.inset_locator import inset_axes # pyright: ignore
10
10
 
11
- from phylogenie.tree import Tree
12
- from phylogenie.utils import get_node_depth_levels, get_node_depths
11
+ from phylogenie.treesimulator import Tree, get_node_depth_levels, get_node_depths
13
12
 
14
13
 
15
14
  class Coloring(str, Enum):
@@ -20,13 +19,18 @@ class Coloring(str, Enum):
20
19
  Color = str | tuple[float, float, float] | tuple[float, float, float, float]
21
20
 
22
21
 
23
- def _draw_colored_tree(tree: Tree, ax: Axes, colors: Color | dict[Tree, Color]) -> Axes:
22
+ def draw_colored_tree(
23
+ tree: Tree, ax: Axes | None = None, colors: Color | dict[Tree, Color] = "black"
24
+ ) -> Axes:
25
+ if ax is None:
26
+ ax = plt.gca()
27
+
24
28
  if not isinstance(colors, dict):
25
29
  colors = {node: colors for node in tree}
26
30
 
27
31
  xs = (
28
32
  get_node_depth_levels(tree)
29
- if any(node.branch_length is None for node in islice(tree, 1, None))
33
+ if any(node.branch_length is None for node in tree.iter_descendants())
30
34
  else get_node_depths(tree)
31
35
  )
32
36
  ys: dict[Tree, float] = {node: i for i, node in enumerate(tree.get_leaves())}
@@ -34,14 +38,14 @@ def _draw_colored_tree(tree: Tree, ax: Axes, colors: Color | dict[Tree, Color])
34
38
  if node.is_internal():
35
39
  ys[node] = sum(ys[child] for child in node.children) / len(node.children)
36
40
 
41
+ if tree.branch_length is not None:
42
+ ax.hlines(y=ys[tree], xmin=0, xmax=xs[tree], color=colors[tree]) # pyright: ignore
37
43
  for node in tree:
38
44
  x1, y1 = xs[node], ys[node]
39
- if node.parent is None:
40
- ax.hlines(y=y1, xmin=0, xmax=x1, color=colors[node]) # pyright: ignore
41
- continue
42
- x0, y0 = xs[node.parent], ys[node.parent]
43
- ax.vlines(x=x0, ymin=y0, ymax=y1, color=colors[node]) # pyright: ignore
44
- ax.hlines(y=y1, xmin=x0, xmax=x1, color=colors[node]) # pyright: ignore
45
+ for child in node.children:
46
+ x2, y2 = xs[child], ys[child]
47
+ ax.hlines(y=y2, xmin=x1, xmax=x2, color=colors[child]) # pyright: ignore
48
+ ax.vlines(x=x1, ymin=y1, ymax=y2, color=colors[child]) # pyright: ignore
45
49
 
46
50
  ax.set_yticks([]) # pyright: ignore
47
51
  return ax
@@ -50,10 +54,10 @@ def _draw_colored_tree(tree: Tree, ax: Axes, colors: Color | dict[Tree, Color])
50
54
  def draw_tree(
51
55
  tree: Tree,
52
56
  ax: Axes | None = None,
53
- color_by: str | None = None,
57
+ color_by: str | dict[str, Any] | None = None,
54
58
  coloring: str | Coloring | None = None,
55
59
  default_color: Color = "black",
56
- cmap: str | None = None,
60
+ colormap: str | Colormap | None = None,
57
61
  vmin: float | None = None,
58
62
  vmax: float | None = None,
59
63
  show_legend: bool = True,
@@ -67,35 +71,40 @@ def draw_tree(
67
71
  ax = plt.gca()
68
72
 
69
73
  if color_by is None:
70
- return _draw_colored_tree(tree, ax, colors=default_color)
74
+ return draw_colored_tree(tree, ax, colors=default_color)
71
75
 
72
- features = [node.get(color_by) for node in tree if color_by in node.features]
76
+ if isinstance(color_by, str):
77
+ features = {node: node[color_by] for node in tree if color_by in node.metadata}
78
+ else:
79
+ features = {node: color_by[node.name] for node in tree if node.name in color_by}
80
+ values = list(features.values())
73
81
 
74
82
  if coloring is None:
75
83
  coloring = (
76
84
  Coloring.CONTINUOUS
77
- if any(isinstance(f, float) for f in features)
85
+ if any(isinstance(f, float) for f in values)
78
86
  else Coloring.DISCRETE
79
87
  )
88
+ if colormap is None:
89
+ colormap = "tab20" if coloring == Coloring.DISCRETE else "viridis"
90
+ if isinstance(colormap, str):
91
+ colormap = plt.get_cmap(colormap)
92
+
93
+ def _get_colors(feature_map: Callable[[Any], Color]) -> dict[Tree, Color]:
94
+ return {
95
+ node: feature_map(features[node]) if node in features else default_color
96
+ for node in tree
97
+ }
80
98
 
81
99
  if coloring == Coloring.DISCRETE:
82
- if any(isinstance(f, float) for f in features):
100
+ if any(isinstance(f, float) for f in values):
83
101
  raise ValueError(
84
102
  "Discrete coloring selected but feature values are not all categorical."
85
103
  )
86
-
87
- colormap = plt.get_cmap("tab20" if cmap is None else cmap)
88
104
  feature_colors = {
89
- f: mcolors.to_hex(colormap(i)) for i, f in enumerate(set(features))
90
- }
91
- colors = {
92
- node: (
93
- feature_colors[node.get(color_by)]
94
- if color_by in node.features
95
- else default_color
96
- )
97
- for node in tree
105
+ f: mcolors.to_hex(colormap(i)) for i, f in enumerate(set(values))
98
106
  }
107
+ colors = _get_colors(lambda f: feature_colors[f])
99
108
 
100
109
  if show_legend:
101
110
  legend_handles = [
@@ -105,27 +114,19 @@ def draw_tree(
105
114
  )
106
115
  for f in feature_colors
107
116
  ]
108
- if any(color_by not in node.features for node in tree):
117
+ if any(color_by not in node.metadata for node in tree):
109
118
  legend_handles.append(mpatches.Patch(color=default_color, label="NA"))
110
119
  if legend_kwargs is None:
111
120
  legend_kwargs = {}
112
121
  ax.legend(handles=legend_handles, **legend_kwargs) # pyright: ignore
113
122
 
114
- return _draw_colored_tree(tree, ax, colors)
123
+ return draw_colored_tree(tree, ax, colors)
115
124
 
116
125
  if coloring == Coloring.CONTINUOUS:
117
- vmin = min(features) if vmin is None else vmin
118
- vmax = max(features) if vmax is None else vmax
126
+ vmin = min(values) if vmin is None else vmin
127
+ vmax = max(values) if vmax is None else vmax
119
128
  norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
120
- colormap = plt.get_cmap("viridis" if cmap is None else cmap)
121
- colors = {
122
- node: (
123
- colormap(norm(float(node.get(color_by))))
124
- if color_by in node.features
125
- else default_color
126
- )
127
- for node in tree
128
- }
129
+ colors = _get_colors(lambda f: colormap(norm(float(f))))
129
130
 
130
131
  if show_hist:
131
132
  default_hist_axes_kwargs = {"width": "25%", "height": "25%"}
@@ -134,19 +135,17 @@ def draw_tree(
134
135
  hist_ax = inset_axes(ax, **default_hist_axes_kwargs) # pyright: ignore
135
136
 
136
137
  hist_kwargs = {} if hist_kwargs is None else hist_kwargs
137
- _, bins, patches = hist_ax.hist(features, **hist_kwargs) # pyright: ignore
138
+ _, bins, patches = hist_ax.hist(values, **hist_kwargs) # pyright: ignore
138
139
 
139
- for patch, b0, b1 in zip( # pyright: ignore
140
- patches, bins[:-1], bins[1:] # pyright: ignore
141
- ):
140
+ for patch, b0, b1 in zip(patches, bins[:-1], bins[1:]): # pyright: ignore
142
141
  midpoint = (b0 + b1) / 2 # pyright: ignore
143
142
  patch.set_facecolor(colormap(norm(midpoint))) # pyright: ignore
144
- return _draw_colored_tree(tree, ax, colors), hist_ax # pyright: ignore
143
+ return draw_colored_tree(tree, ax, colors), hist_ax # pyright: ignore
145
144
 
146
145
  else:
147
146
  sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)
148
147
  ax.get_figure().colorbar(sm, ax=ax) # pyright: ignore
149
- return _draw_colored_tree(tree, ax, colors)
148
+ return draw_colored_tree(tree, ax, colors)
150
149
 
151
150
  raise ValueError(
152
151
  f"Unknown coloring method: {coloring}. Choices are {list(Coloring)}."
@@ -8,8 +8,7 @@ from numpy.random import Generator, default_rng
8
8
  from phylogenie.generators.dataset import DatasetGenerator, DataType
9
9
  from phylogenie.generators.factories import data, string
10
10
  from phylogenie.generators.trees import TreeDatasetGeneratorConfig
11
- from phylogenie.io import dump_newick
12
- from phylogenie.utils import get_node_depths
11
+ from phylogenie.treesimulator import dump_newick, get_node_depths
13
12
 
14
13
  MSAS_DIRNAME = "MSAs"
15
14
  TREES_DIRNAME = "trees"
@@ -60,26 +59,26 @@ class AliSimDatasetGenerator(DatasetGenerator):
60
59
  tree_filename = f"{filename}.temp-tree"
61
60
  msa_filename = filename
62
61
 
63
- d: dict[str, Any] = {"file_id": Path(msa_filename).stem}
62
+ md: dict[str, Any] = {"file_id": Path(msa_filename).stem}
64
63
  rng = default_rng(seed)
65
64
  while True:
66
- d.update(data(context, rng))
65
+ md.update(data(context, rng))
67
66
  try:
68
- tree, metadata = self.trees.simulate_one(d, seed)
67
+ tree, metadata = self.trees.simulate_one(md, seed)
69
68
  break
70
69
  except TimeoutError:
71
70
  print(
72
71
  "Tree simulation timed out, retrying with different parameters..."
73
72
  )
74
- d.update(metadata)
73
+ md.update(metadata)
75
74
 
76
75
  times = get_node_depths(tree)
77
76
  for leaf in tree.get_leaves():
78
77
  leaf.name += f"|{times[leaf]}"
79
78
  dump_newick(tree, f"{tree_filename}.nwk")
80
79
 
81
- self._generate_one_from_tree(msa_filename, f"{tree_filename}.nwk", rng, d)
80
+ self._generate_one_from_tree(msa_filename, f"{tree_filename}.nwk", rng, md)
82
81
  if not self.keep_trees:
83
82
  os.remove(f"{tree_filename}.nwk")
84
83
 
85
- return d
84
+ return md
@@ -1,7 +1,29 @@
1
+ from typing import Any
2
+
3
+ from numpy.random import Generator
4
+ from pydantic import BaseModel, ConfigDict
5
+
1
6
  import phylogenie.typings as pgt
2
- from phylogenie.models import Distribution, StrictBaseModel
3
7
  from phylogenie.treesimulator import EventType
4
8
 
9
+
10
+ class StrictBaseModel(BaseModel):
11
+ model_config = ConfigDict(extra="forbid")
12
+
13
+
14
+ class Distribution(BaseModel):
15
+ type: str
16
+ model_config = ConfigDict(extra="allow")
17
+
18
+ @property
19
+ def args(self) -> dict[str, Any]:
20
+ assert self.model_extra is not None
21
+ return self.model_extra
22
+
23
+ def __call__(self, rng: Generator) -> Any:
24
+ return getattr(rng, self.type)(**self.args)
25
+
26
+
5
27
  Integer = str | int
6
28
  Scalar = str | pgt.Scalar
7
29
  ManyScalars = str | pgt.Many[Scalar]
@@ -8,7 +8,7 @@ import pandas as pd
8
8
  from numpy.random import Generator, default_rng
9
9
  from tqdm import tqdm
10
10
 
11
- from phylogenie.models import Distribution, StrictBaseModel
11
+ from phylogenie.generators.configs import Distribution, StrictBaseModel
12
12
 
13
13
 
14
14
  class DataType(str, Enum):
@@ -1,5 +1,5 @@
1
1
  import re
2
- from typing import Any
2
+ from typing import Any, Callable
3
3
 
4
4
  import numpy as np
5
5
  from numpy.random import Generator
@@ -8,7 +8,6 @@ import phylogenie.generators.configs as cfg
8
8
  import phylogenie.generators.typeguards as ctg
9
9
  import phylogenie.typeguards as tg
10
10
  import phylogenie.typings as pgt
11
- from phylogenie.models import Distribution
12
11
  from phylogenie.skyline import (
13
12
  SkylineMatrix,
14
13
  SkylineMatrixCoercible,
@@ -24,11 +23,7 @@ def _eval_expression(expression: str, data: dict[str, Any]) -> Any:
24
23
  return np.array(
25
24
  eval(
26
25
  expression,
27
- {
28
- "__builtins__": __builtins__,
29
- "np": np,
30
- **{k: np.array(v) for k, v in data.items()},
31
- },
26
+ {"np": np, **{k: np.array(v) for k, v in data.items()}},
32
27
  )
33
28
  ).tolist()
34
29
 
@@ -214,12 +209,12 @@ def skyline_matrix(
214
209
  return SkylineMatrix(value=value, change_times=change_times)
215
210
 
216
211
 
217
- def distribution(x: Distribution, data: dict[str, Any]) -> Distribution:
212
+ def distribution(x: cfg.Distribution, data: dict[str, Any]) -> cfg.Distribution:
218
213
  args = x.args
219
214
  for arg_name, arg_value in args.items():
220
215
  if isinstance(arg_value, str):
221
216
  args[arg_name] = _eval_expression(arg_value, data)
222
- return Distribution(type=x.type, **args)
217
+ return cfg.Distribution(type=x.type, **args)
223
218
 
224
219
 
225
220
  def mutations(
@@ -227,11 +222,14 @@ def mutations(
227
222
  data: dict[str, Any],
228
223
  states: set[str],
229
224
  rates_to_log: list[EventType] | None,
225
+ rng: Generator,
230
226
  ) -> list[Mutation]:
231
227
  mutations: list[Mutation] = []
232
228
  for m in x:
233
229
  rate = skyline_parameter(m.rate, data)
234
- rate_scalers = {k: distribution(v, data) for k, v in m.rate_scalers.items()}
230
+ rate_scalers: dict[EventType, Callable[[], float]] = {
231
+ k: lambda: distribution(v, data)(rng) for k, v in m.rate_scalers.items()
232
+ }
235
233
  if m.state is None:
236
234
  mutations.extend(
237
235
  Mutation(s, rate, rate_scalers, rates_to_log) for s in states
@@ -241,11 +239,10 @@ def mutations(
241
239
  return mutations
242
240
 
243
241
 
244
- def data(context: dict[str, Distribution] | None, rng: Generator) -> dict[str, Any]:
242
+ def data(context: dict[str, cfg.Distribution] | None, rng: Generator) -> dict[str, Any]:
245
243
  if context is None:
246
244
  return {}
247
245
  data: dict[str, Any] = {}
248
246
  for k, v in context.items():
249
- dist = distribution(v, data)
250
- data[k] = np.array(getattr(rng, dist.type)(**dist.args)).tolist()
247
+ data[k] = np.array(distribution(v, data)(rng)).tolist()
251
248
  return data
@@ -1,13 +1,14 @@
1
1
  from abc import abstractmethod
2
2
  from enum import Enum
3
3
  from pathlib import Path
4
- from typing import Annotated, Any, Literal
4
+ from typing import Annotated, Any, Callable, Literal
5
5
 
6
6
  import numpy as np
7
7
  from numpy.random import default_rng
8
8
  from pydantic import Field
9
9
 
10
10
  import phylogenie.generators.configs as cfg
11
+ from phylogenie.generators.configs import Distribution
11
12
  from phylogenie.generators.dataset import DatasetGenerator, DataType
12
13
  from phylogenie.generators.factories import (
13
14
  data,
@@ -18,13 +19,12 @@ from phylogenie.generators.factories import (
18
19
  skyline_parameter,
19
20
  skyline_vector,
20
21
  )
21
- from phylogenie.io import dump_newick
22
- from phylogenie.models import Distribution
23
- from phylogenie.tree import Tree
24
22
  from phylogenie.treesimulator import (
25
23
  Event,
26
24
  EventType,
27
25
  Feature,
26
+ Tree,
27
+ dump_newick,
28
28
  get_BD_events,
29
29
  get_BDEI_events,
30
30
  get_BDSS_events,
@@ -50,13 +50,13 @@ class TreeDatasetGenerator(DatasetGenerator):
50
50
  data_type: Literal[DataType.TREES] = DataType.TREES
51
51
  mutations: list[cfg.Mutation] = Field(default_factory=lambda: [])
52
52
  rates_to_log: list[EventType] | None = None
53
- min_tips: cfg.Integer = 1
54
- max_tips: cfg.Integer | None = None
53
+ n_tips: cfg.Integer | None = None
55
54
  max_time: cfg.Scalar = np.inf
56
55
  init_state: str | None = None
57
56
  sampling_probability_at_present: cfg.Scalar = 0.0
58
57
  timeout: float = np.inf
59
58
  node_features: list[Feature] | None = None
59
+ acceptance_criterion: str | None = None
60
60
 
61
61
  @abstractmethod
62
62
  def _get_events(self, data: dict[str, Any]) -> list[Event]: ...
@@ -71,11 +71,19 @@ class TreeDatasetGenerator(DatasetGenerator):
71
71
  )
72
72
  events = self._get_events(data)
73
73
  states = {e.state for e in events}
74
- events += mutations(self.mutations, data, states, self.rates_to_log)
74
+ events += mutations(
75
+ self.mutations, data, states, self.rates_to_log, default_rng(seed)
76
+ )
77
+ acceptance_criterion: None | Callable[[Tree], bool] = (
78
+ None
79
+ if self.acceptance_criterion is None
80
+ else lambda tree: eval(
81
+ self.acceptance_criterion, {}, {"tree": tree} # pyright: ignore
82
+ )
83
+ )
75
84
  return simulate_tree(
76
85
  events=events,
77
- min_tips=integer(self.min_tips, data),
78
- max_tips=None if self.max_tips is None else integer(self.max_tips, data),
86
+ n_tips=None if self.n_tips is None else integer(self.n_tips, data),
79
87
  max_time=scalar(self.max_time, data),
80
88
  init_state=init_state,
81
89
  sampling_probability_at_present=scalar(
@@ -83,6 +91,7 @@ class TreeDatasetGenerator(DatasetGenerator):
83
91
  ),
84
92
  seed=seed,
85
93
  timeout=self.timeout,
94
+ acceptance_criterion=acceptance_criterion,
86
95
  )
87
96
 
88
97
  def generate_one(
@@ -157,7 +166,7 @@ class FBDTreeDatasetGenerator(TreeDatasetGenerator):
157
166
  class ContactTracingTreeDatasetGenerator(TreeDatasetGenerator):
158
167
  max_notified_contacts: cfg.Integer = 1
159
168
  notification_probability: cfg.SkylineParameter = 0.0
160
- sampling_rate_after_notification: cfg.SkylineParameter = np.inf
169
+ sampling_rate_after_notification: cfg.SkylineParameter = 2**32
161
170
  samplable_states_after_notification: list[str] | None = None
162
171
 
163
172
  @abstractmethod
@@ -0,0 +1,3 @@
1
+ from phylogenie.io.fasta import load_fasta
2
+
3
+ __all__ = ["load_fasta"]