phylogenie 3.1.5__tar.gz → 3.1.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. {phylogenie-3.1.5/src/phylogenie.egg-info → phylogenie-3.1.9}/PKG-INFO +1 -1
  2. {phylogenie-3.1.5 → phylogenie-3.1.9}/pyproject.toml +1 -1
  3. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/__init__.py +17 -4
  4. phylogenie-3.1.9/src/phylogenie/draw.py +691 -0
  5. phylogenie-3.1.9/src/phylogenie/io/__init__.py +3 -0
  6. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/io/fasta.py +17 -3
  7. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/mixins.py +3 -11
  8. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/msa.py +4 -3
  9. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/treesimulator/__init__.py +6 -0
  10. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/treesimulator/features.py +6 -0
  11. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/treesimulator/io/nexus.py +2 -2
  12. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/treesimulator/tree.py +15 -2
  13. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/treesimulator/utils.py +11 -0
  14. {phylogenie-3.1.5 → phylogenie-3.1.9/src/phylogenie.egg-info}/PKG-INFO +1 -1
  15. phylogenie-3.1.5/src/phylogenie/draw.py +0 -152
  16. phylogenie-3.1.5/src/phylogenie/io/__init__.py +0 -3
  17. {phylogenie-3.1.5 → phylogenie-3.1.9}/LICENSE.txt +0 -0
  18. {phylogenie-3.1.5 → phylogenie-3.1.9}/README.md +0 -0
  19. {phylogenie-3.1.5 → phylogenie-3.1.9}/setup.cfg +0 -0
  20. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/generators/__init__.py +0 -0
  21. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/generators/alisim.py +0 -0
  22. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/generators/configs.py +0 -0
  23. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/generators/dataset.py +0 -0
  24. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/generators/factories.py +0 -0
  25. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/generators/trees.py +0 -0
  26. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/generators/typeguards.py +0 -0
  27. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/main.py +0 -0
  28. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/py.typed +0 -0
  29. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/skyline/__init__.py +0 -0
  30. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/skyline/matrix.py +0 -0
  31. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/skyline/parameter.py +0 -0
  32. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/skyline/vector.py +0 -0
  33. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/treesimulator/events/__init__.py +0 -0
  34. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/treesimulator/events/base.py +0 -0
  35. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/treesimulator/events/contact_tracing.py +0 -0
  36. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/treesimulator/events/core.py +0 -0
  37. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/treesimulator/events/mutations.py +0 -0
  38. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/treesimulator/gillespie.py +0 -0
  39. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/treesimulator/io/__init__.py +0 -0
  40. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/treesimulator/io/newick.py +0 -0
  41. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/treesimulator/model.py +0 -0
  42. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/typeguards.py +0 -0
  43. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie/typings.py +0 -0
  44. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie.egg-info/SOURCES.txt +0 -0
  45. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie.egg-info/dependency_links.txt +0 -0
  46. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie.egg-info/entry_points.txt +0 -0
  47. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie.egg-info/requires.txt +0 -0
  48. {phylogenie-3.1.5 → phylogenie-3.1.9}/src/phylogenie.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: phylogenie
3
- Version: 3.1.5
3
+ Version: 3.1.9
4
4
  Summary: Generate phylogenetic datasets with minimal setup effort
5
5
  Requires-Python: >=3.10
6
6
  Description-Content-Type: text/markdown
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "phylogenie"
3
- version = "3.1.5"
3
+ version = "3.1.9"
4
4
  description = "Generate phylogenetic datasets with minimal setup effort"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.10"
@@ -1,4 +1,11 @@
1
- from phylogenie.draw import Coloring, draw_tree
1
+ from phylogenie.draw import (
2
+ draw_colored_dated_tree_categorical,
3
+ draw_colored_dated_tree_continuous,
4
+ draw_colored_tree_categorical,
5
+ draw_colored_tree_continuous,
6
+ draw_dated_tree,
7
+ draw_tree,
8
+ )
2
9
  from phylogenie.generators import (
3
10
  AliSimDatasetGenerator,
4
11
  BDEITreeDatasetGenerator,
@@ -11,8 +18,8 @@ from phylogenie.generators import (
11
18
  FBDTreeDatasetGenerator,
12
19
  TreeDatasetGeneratorConfig,
13
20
  )
14
- from phylogenie.io import load_fasta
15
- from phylogenie.msa import MSA
21
+ from phylogenie.io import dump_fasta, load_fasta
22
+ from phylogenie.msa import MSA, Sequence
16
23
  from phylogenie.skyline import (
17
24
  SkylineMatrix,
18
25
  SkylineMatrixCoercible,
@@ -64,7 +71,11 @@ from phylogenie.treesimulator import (
64
71
  )
65
72
 
66
73
  __all__ = [
67
- "Coloring",
74
+ "draw_colored_dated_tree_categorical",
75
+ "draw_colored_dated_tree_continuous",
76
+ "draw_colored_tree_categorical",
77
+ "draw_colored_tree_continuous",
78
+ "draw_dated_tree",
68
79
  "draw_tree",
69
80
  "AliSimDatasetGenerator",
70
81
  "BDEITreeDatasetGenerator",
@@ -76,8 +87,10 @@ __all__ = [
76
87
  "EpidemiologicalTreeDatasetGenerator",
77
88
  "FBDTreeDatasetGenerator",
78
89
  "TreeDatasetGeneratorConfig",
90
+ "dump_fasta",
79
91
  "load_fasta",
80
92
  "MSA",
93
+ "Sequence",
81
94
  "SkylineMatrix",
82
95
  "SkylineMatrixCoercible",
83
96
  "SkylineParameter",
@@ -0,0 +1,691 @@
1
+ from dataclasses import dataclass
2
+ from datetime import datetime
3
+ from typing import Any, Literal, overload
4
+
5
+ import matplotlib.colors as mcolors
6
+ import matplotlib.dates as mdates
7
+ import matplotlib.patches as mpatches
8
+ import matplotlib.pyplot as plt
9
+ from matplotlib.axes import Axes
10
+ from matplotlib.colors import Colormap
11
+ from mpl_toolkits.axes_grid1.inset_locator import inset_axes # pyright: ignore
12
+
13
+ from phylogenie.treesimulator import (
14
+ Tree,
15
+ get_node_ages,
16
+ get_node_depth_levels,
17
+ get_node_depths,
18
+ )
19
+
20
+
21
@dataclass
class CalibrationNode:
    """
    A tree node paired with the calendar date assigned to it.

    A pair of calibration nodes defines the linear mapping from node depth to
    calendar date used by the dated drawing functions in this module.
    """

    # The node of the tree being calibrated.
    node: Tree
    # The calendar date attached to that node.
    date: datetime


# A color specification: a string (e.g. a color name or hex code), or an
# RGB / RGBA tuple of floats.
Color = str | tuple[float, float, float] | tuple[float, float, float, float]
28
+
29
+
30
def draw_tree(
    tree: Tree,
    ax: Axes | None = None,
    colors: Color | dict[Tree, Color] = "black",
    backward_time: bool = False,
) -> Axes:
    """
    Draw a phylogenetic tree as a rectangular diagram with colored branches.

    Every branch is drawn as a horizontal segment ending at the child node,
    plus a vertical connector at the parent's x position, both in the child's
    color. Leaves sit at consecutive integer y positions (in `get_leaves()`
    order) and each internal node is centered on the mean y of its children.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree to draw.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, uses the current Axes.
    colors : Color | dict[Tree, Color], optional
        A single color for all branches or a dictionary mapping each node to a
        color; a node's color is used for the branch leading to it.
    backward_time : bool, optional
        If True, nodes are placed by age and the x-axis is inverted to
        represent time going backward.

    Returns
    -------
    Axes
        The Axes with the drawn tree.
    """
    target = plt.gca() if ax is None else ax
    node_colors = colors if isinstance(colors, dict) else dict.fromkeys(tree, colors)

    # Horizontal placement: ages when looking backward in time; otherwise
    # depths, falling back to depth levels when any branch length is missing.
    if backward_time:
        xs = get_node_ages(tree)
    elif any(node.branch_length is None for node in tree.iter_descendants()):
        xs = get_node_depth_levels(tree)
    else:
        xs = get_node_depths(tree)

    # Vertical placement: leaves first, then internal nodes bottom-up.
    ys: dict[Tree, float] = {}
    for index, leaf in enumerate(tree.get_leaves()):
        ys[leaf] = index
    for node in tree.postorder_traversal():
        if node.is_internal():
            kids = node.children
            ys[node] = sum(ys[kid] for kid in kids) / len(kids)

    # Stem above the root, drawn only when the root has a branch length.
    root_length = tree.branch_length
    if root_length is not None:
        start = xs[tree] + root_length if backward_time else 0
        target.hlines(y=ys[tree], xmin=start, xmax=xs[tree], color=node_colors[tree])  # pyright: ignore

    for parent in tree:
        parent_x, parent_y = xs[parent], ys[parent]
        for kid in parent.children:
            kid_x, kid_y = xs[kid], ys[kid]
            kid_color = node_colors[kid]
            target.hlines(y=kid_y, xmin=parent_x, xmax=kid_x, color=kid_color)  # pyright: ignore
            target.vlines(x=parent_x, ymin=parent_y, ymax=kid_y, color=kid_color)  # pyright: ignore

    if backward_time:
        target.invert_xaxis()

    target.set_yticks([])  # pyright: ignore
    return target
89
+
90
+
91
def _depth_to_date(
    depth: float, calibration_nodes: tuple[CalibrationNode, CalibrationNode]
) -> datetime:
    """
    Convert a depth value to a date using linear interpolation between two calibration nodes.

    The mapping is the straight line through the points
    (depth of node 1, date of node 1) and (depth of node 2, date of node 2),
    where each depth is read from the `depth` attribute of the calibration
    node's tree node. Depths outside that range are extrapolated.

    Parameters
    ----------
    depth : float
        The depth value to convert.
    calibration_nodes : tuple[CalibrationNode, CalibrationNode]
        Two calibration nodes defining the mapping from depth to date.

    Returns
    -------
    datetime
        The interpolated date corresponding to the given depth.

    Raises
    ------
    ValueError
        If both calibration nodes have the same depth, in which case the
        depth-to-date mapping is undefined.
    """
    node1, node2 = calibration_nodes
    depth1, depth2 = node1.node.depth, node2.node.depth
    # Equal depths would make the slope below a division by zero; report the
    # actual problem instead of an opaque ZeroDivisionError.
    if depth1 == depth2:
        raise ValueError(
            "Calibration nodes must have different depths to define a "
            f"depth-to-date mapping (both have depth {depth1})."
        )
    date1, date2 = node1.date, node2.date
    return date1 + (depth - depth1) * (date2 - date1) / (depth2 - depth1)
113
+
114
+
115
def draw_dated_tree(
    tree: Tree,
    calibration_nodes: tuple[CalibrationNode, CalibrationNode],
    ax: Axes | None = None,
    colors: Color | dict[Tree, Color] = "black",
) -> Axes:
    """
    Draw a phylogenetic tree on a calendar-date x-axis.

    Node depths are mapped to dates by linear interpolation between the two
    calibration nodes. The dates are then converted to matplotlib date numbers.
    The x-axis uses an automatic date locator, `%Y-%m-%d` labels, and a 45°
    label rotation.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree to draw.
    calibration_nodes : tuple[CalibrationNode, CalibrationNode]
        Two calibration nodes defining the mapping from depth to date.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, uses the current Axes.
    colors : Color | dict[Tree, Color], optional
        A single color for all branches or a dictionary mapping each node to a
        color; a node's color is used for the branch leading to it.

    Returns
    -------
    Axes
        The Axes with the drawn dated tree.
    """
    target = plt.gca() if ax is None else ax
    node_colors = colors if isinstance(colors, dict) else dict.fromkeys(tree, colors)

    # Each node's x coordinate, expressed in matplotlib's numeric date units.
    date_x = {
        node: mdates.date2num(  # pyright: ignore
            _depth_to_date(depth=depth, calibration_nodes=calibration_nodes)
        )
        for node, depth in get_node_depths(tree).items()
    }

    # Leaves at consecutive integers; internal nodes at the mean of children.
    ys: dict[Tree, float] = {}
    for index, leaf in enumerate(tree.get_leaves()):
        ys[leaf] = index
    for node in tree.postorder_traversal():
        if node.is_internal():
            kids = node.children
            ys[node] = sum(ys[kid] for kid in kids) / len(kids)

    # Stem from depth 0 to the root, drawn only when the root has a branch length.
    if tree.branch_length is not None:
        origin_x = mdates.date2num(  # pyright: ignore
            _depth_to_date(depth=0, calibration_nodes=calibration_nodes)
        )
        target.hlines(  # pyright: ignore
            y=ys[tree],
            xmin=origin_x,
            xmax=date_x[tree],
            color=node_colors[tree],
        )

    for parent in tree:
        parent_x, parent_y = date_x[parent], ys[parent]
        for kid in parent.children:
            kid_color = node_colors[kid]
            target.hlines(  # pyright: ignore
                y=ys[kid],
                xmin=parent_x,
                xmax=date_x[kid],
                color=kid_color,
            )
            target.vlines(x=parent_x, ymin=parent_y, ymax=ys[kid], color=kid_color)  # pyright: ignore

    target.xaxis.set_major_locator(mdates.AutoDateLocator())
    target.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d"))
    target.tick_params(axis="x", labelrotation=45)  # pyright: ignore

    target.set_yticks([])  # pyright: ignore
    return target
182
+
183
+
184
def _init_colored_tree_categorical(
    tree: Tree,
    color_by: str,
    ax: Axes | None = None,
    default_color: Color = "black",
    colormap: str | Colormap = "tab20",
    show_legend: bool = True,
    labels: dict[Any, str] | None = None,
    legend_kwargs: dict[str, Any] | None = None,
) -> tuple[Axes, dict[Tree, Color]]:
    """
    Initialize colors for drawing a tree based on categorical metadata.

    Each distinct value of the `color_by` metadata gets its own color from the
    colormap. Colors are assigned in the order in which the categories are
    first met while iterating over the tree, so the result is reproducible
    between runs. Nodes without the metadata get `default_color`.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree.
    color_by : str
        The metadata key to color branches by.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, uses the current Axes.
    default_color : Color, optional
        The color to use for nodes without the specified metadata.
    colormap : str | Colormap, optional
        The colormap to use for coloring categories. Defaults to 'tab20'.
        If there are more categories than colormap entries, the colors cycle.
    show_legend : bool, optional
        Whether to display a legend for the categories. An extra "NA" entry is
        added when some node lacks the metadata.
    labels : dict[Any, str] | None, optional
        A mapping from category values to labels for the legend. Categories
        not present in the mapping are labeled with `str(value)`.
    legend_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to pass to the legend.

    Returns
    -------
    tuple[Axes, dict[Tree, Color]]
        The Axes and a dictionary mapping each node to its assigned color.
    """
    if ax is None:
        ax = plt.gca()

    if isinstance(colormap, str):
        colormap = plt.get_cmap(colormap)

    features = {node: node[color_by] for node in tree if color_by in node.metadata}

    # Order categories by first appearance instead of iterating a set: set
    # order for strings follows hash randomization, which would reshuffle the
    # colors and legend entries from one interpreter run to the next.
    categories = list(dict.fromkeys(features.values()))
    # Integer inputs index the colormap's lookup table and are clipped at its
    # end, so wrap around instead of giving all overflow categories one color.
    feature_colors = {
        f: mcolors.to_hex(colormap(i % colormap.N)) for i, f in enumerate(categories)
    }
    colors = {
        node: feature_colors[features[node]] if node in features else default_color
        for node in tree
    }

    if show_legend:
        legend_handles = [
            mpatches.Patch(
                color=feature_colors[f],
                label=str(f) if labels is None else labels.get(f, str(f)),
            )
            for f in feature_colors
        ]
        if any(color_by not in node.metadata for node in tree):
            legend_handles.append(mpatches.Patch(color=default_color, label="NA"))
        if legend_kwargs is None:
            legend_kwargs = {}
        ax.legend(handles=legend_handles, **legend_kwargs)  # pyright: ignore

    return ax, colors
251
+
252
+
253
def draw_colored_tree_categorical(
    tree: Tree,
    color_by: str,
    ax: Axes | None = None,
    backward_time: bool = False,
    default_color: Color = "black",
    colormap: str | Colormap = "tab20",
    show_legend: bool = True,
    labels: dict[Any, str] | None = None,
    legend_kwargs: dict[str, Any] | None = None,
) -> Axes:
    """
    Draw a phylogenetic tree with branches colored based on categorical metadata.

    Colors (and the optional legend) are set up by
    `_init_colored_tree_categorical`, and the tree is then drawn with
    `draw_tree`.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree to draw.
    color_by : str
        The metadata key to color branches by.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, uses the current Axes.
    backward_time : bool, optional
        If True, the x-axis is inverted to represent time going backward.
    default_color : Color, optional
        The color to use for nodes without the specified metadata.
    colormap : str | Colormap, optional
        The colormap to use for coloring categories. Defaults to 'tab20'.
    show_legend : bool, optional
        Whether to display a legend for the categories.
    labels : dict[Any, str] | None, optional
        A mapping from category values to labels for the legend.
    legend_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to pass to the legend.

    Returns
    -------
    Axes
        The Axes with the drawn colored tree.
    """
    ax, colors = _init_colored_tree_categorical(
        tree=tree,
        color_by=color_by,
        ax=ax,
        default_color=default_color,
        colormap=colormap,
        show_legend=show_legend,
        labels=labels,
        legend_kwargs=legend_kwargs,
    )
    return draw_tree(tree=tree, ax=ax, colors=colors, backward_time=backward_time)
304
+
305
+
306
def draw_colored_dated_tree_categorical(
    tree: Tree,
    calibration_nodes: tuple[CalibrationNode, CalibrationNode],
    color_by: str,
    ax: Axes | None = None,
    default_color: Color = "black",
    colormap: str | Colormap = "tab20",
    show_legend: bool = True,
    labels: dict[Any, str] | None = None,
    legend_kwargs: dict[str, Any] | None = None,
) -> Axes:
    """
    Draw a date-calibrated tree whose branches are colored by a categorical metadata key.

    The per-node colors (and the optional legend) come from
    `_init_colored_tree_categorical`; the tree itself is rendered by
    `draw_dated_tree`.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree to draw.
    calibration_nodes : tuple[CalibrationNode, CalibrationNode]
        Two calibration nodes defining the mapping from depth to date.
    color_by : str
        The metadata key to color branches by.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, uses the current Axes.
    default_color : Color, optional
        The color to use for nodes without the specified metadata.
    colormap : str | Colormap, optional
        The colormap to use for coloring categories. Defaults to 'tab20'.
    show_legend : bool, optional
        Whether to display a legend for the categories.
    labels : dict[Any, str] | None, optional
        A mapping from category values to labels for the legend.
    legend_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to pass to the legend.

    Returns
    -------
    Axes
        The Axes with the drawn colored dated tree.
    """
    target_ax, node_colors = _init_colored_tree_categorical(
        tree,
        color_by,
        ax=ax,
        default_color=default_color,
        colormap=colormap,
        show_legend=show_legend,
        labels=labels,
        legend_kwargs=legend_kwargs,
    )
    drawn_ax = draw_dated_tree(
        tree, calibration_nodes, ax=target_ax, colors=node_colors
    )
    return drawn_ax
359
+
360
+
361
@overload
def _init_colored_tree_continuous(
    tree: Tree,
    color_by: str,
    ax: Axes | None = ...,
    default_color: Color = ...,
    colormap: str | Colormap = ...,
    vmin: float | None = ...,
    vmax: float | None = ...,
    *,
    show_hist: Literal[False],
    hist_kwargs: dict[str, Any] | None = ...,
    hist_axes_kwargs: dict[str, Any] | None = ...,
) -> tuple[Axes, dict[Tree, Color]]: ...
@overload
def _init_colored_tree_continuous(
    tree: Tree,
    color_by: str,
    ax: Axes | None = ...,
    default_color: Color = ...,
    colormap: str | Colormap = ...,
    vmin: float | None = ...,
    vmax: float | None = ...,
    *,
    show_hist: Literal[True] = True,
    hist_kwargs: dict[str, Any] | None = ...,
    hist_axes_kwargs: dict[str, Any] | None = ...,
) -> tuple[Axes, dict[Tree, Color], Axes]: ...
def _init_colored_tree_continuous(
    tree: Tree,
    color_by: str,
    ax: Axes | None = None,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    show_hist: bool = True,
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> tuple[Axes, dict[Tree, Color]] | tuple[Axes, dict[Tree, Color], Axes]:
    """
    Initialize colors for drawing a tree based on continuous metadata.

    Metadata values are converted to float and normalized to [vmin, vmax]
    before being passed through the colormap. Nodes without the metadata get
    `default_color`. When `show_hist` is True, an inset histogram of the
    values is added and each bar is colored at its bin midpoint. Otherwise a
    colorbar is attached to the figure.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree.
    color_by : str
        The metadata key to color branches by.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, uses the current Axes.
    default_color : Color, optional
        The color to use for nodes without the specified metadata.
    colormap : str | Colormap, optional
        The colormap to use for coloring continuous values. Defaults to 'viridis'.
    vmin : float | None, optional
        The minimum value for normalization. If None, uses the minimum of the data.
    vmax : float | None, optional
        The maximum value for normalization. If None, uses the maximum of the data.
    show_hist : bool, optional
        Whether to display a histogram of the continuous values.
    hist_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to pass to the histogram.
    hist_axes_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to define the histogram Axes.

    Returns
    -------
    tuple[Axes, dict[Tree, Color]] | tuple[Axes, dict[Tree, Color], Axes]
        The Axes, a dictionary mapping each node to its assigned color,
        and optionally the histogram Axes if `show_hist` is True.

    Raises
    ------
    ValueError
        If no node carries the `color_by` metadata and `vmin` or `vmax` has to
        be inferred from the data.
    """
    if ax is None:
        ax = plt.gca()

    if isinstance(colormap, str):
        colormap = plt.get_cmap(colormap)

    # Convert to float once so that the inferred range, the branch colors and
    # the histogram all operate on the same numeric values (taking min/max of
    # e.g. numeric strings would otherwise compare lexicographically).
    features = {
        node: float(node[color_by]) for node in tree if color_by in node.metadata
    }
    values = list(features.values())
    if not values and (vmin is None or vmax is None):
        raise ValueError(
            f"No node has '{color_by}' metadata, so the color range cannot be "
            "inferred; pass both vmin and vmax explicitly."
        )
    vmin = min(values) if vmin is None else vmin
    vmax = max(values) if vmax is None else vmax
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
    colors = {
        node: colormap(norm(features[node])) if node in features else default_color
        for node in tree
    }

    if show_hist:
        default_hist_axes_kwargs = {"width": "25%", "height": "25%"}
        if hist_axes_kwargs is not None:
            default_hist_axes_kwargs.update(hist_axes_kwargs)
        hist_ax = inset_axes(ax, **default_hist_axes_kwargs)  # pyright: ignore

        hist_kwargs = {} if hist_kwargs is None else hist_kwargs
        _, bins, patches = hist_ax.hist(values, **hist_kwargs)  # pyright: ignore

        # Color each bar like a branch whose value sits at the bar's midpoint.
        for patch, b0, b1 in zip(patches, bins[:-1], bins[1:]):  # pyright: ignore
            midpoint = (b0 + b1) / 2  # pyright: ignore
            patch.set_facecolor(colormap(norm(midpoint)))  # pyright: ignore
        return ax, colors, hist_ax  # pyright: ignore

    sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)
    ax.get_figure().colorbar(sm, ax=ax)  # pyright: ignore
    return ax, colors
468
+
469
+
470
@overload
def draw_colored_tree_continuous(
    tree: Tree,
    color_by: str,
    ax: Axes | None = None,
    backward_time: bool = False,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    *,
    show_hist: Literal[False],
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> Axes: ...
@overload
def draw_colored_tree_continuous(
    tree: Tree,
    color_by: str,
    ax: Axes | None = None,
    backward_time: bool = False,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    show_hist: Literal[True] = True,
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> tuple[Axes, Axes]: ...
def draw_colored_tree_continuous(
    tree: Tree,
    color_by: str,
    ax: Axes | None = None,
    backward_time: bool = False,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    show_hist: bool = True,
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> Axes | tuple[Axes, Axes]:
    """
    Draw a tree whose branches are colored by a continuous metadata key.

    Colors come from `_init_colored_tree_continuous`, which also adds either
    an inset histogram (`show_hist=True`) or a colorbar. The tree is then
    drawn with `draw_tree`.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree to draw.
    color_by : str
        The metadata key to color branches by.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, uses the current Axes.
    backward_time : bool, optional
        If True, the x-axis is inverted to represent time going backward.
    default_color : Color, optional
        The color to use for nodes without the specified metadata.
    colormap : str | Colormap, optional
        The colormap to use for coloring continuous values. Defaults to 'viridis'.
    vmin : float | None, optional
        The minimum value for normalization. If None, uses the minimum of the data.
    vmax : float | None, optional
        The maximum value for normalization. If None, uses the maximum of the data.
    show_hist : bool, optional
        Whether to display a histogram of the continuous values.
    hist_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to pass to the histogram.
    hist_axes_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to define the histogram Axes.

    Returns
    -------
    Axes | tuple[Axes, Axes]
        The tree Axes alone when `show_hist` is False, otherwise the pair
        (tree Axes, histogram Axes).
    """
    if not show_hist:
        tree_ax, node_colors = _init_colored_tree_continuous(
            tree=tree,
            color_by=color_by,
            ax=ax,
            default_color=default_color,
            colormap=colormap,
            vmin=vmin,
            vmax=vmax,
            show_hist=False,
            hist_kwargs=hist_kwargs,
            hist_axes_kwargs=hist_axes_kwargs,
        )
        return draw_tree(
            tree=tree, ax=tree_ax, colors=node_colors, backward_time=backward_time
        )

    tree_ax, node_colors, hist_ax = _init_colored_tree_continuous(
        tree=tree,
        color_by=color_by,
        ax=ax,
        default_color=default_color,
        colormap=colormap,
        vmin=vmin,
        vmax=vmax,
        show_hist=True,
        hist_kwargs=hist_kwargs,
        hist_axes_kwargs=hist_axes_kwargs,
    )
    tree_ax = draw_tree(
        tree=tree, ax=tree_ax, colors=node_colors, backward_time=backward_time
    )
    return tree_ax, hist_ax
576
+
577
+
578
@overload
def draw_colored_dated_tree_continuous(
    tree: Tree,
    calibration_nodes: tuple[CalibrationNode, CalibrationNode],
    color_by: str,
    ax: Axes | None = None,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    *,
    show_hist: Literal[False],
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> Axes: ...
@overload
def draw_colored_dated_tree_continuous(
    tree: Tree,
    calibration_nodes: tuple[CalibrationNode, CalibrationNode],
    color_by: str,
    ax: Axes | None = None,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    show_hist: Literal[True] = True,
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> tuple[Axes, Axes]: ...
def draw_colored_dated_tree_continuous(
    tree: Tree,
    calibration_nodes: tuple[CalibrationNode, CalibrationNode],
    color_by: str,
    ax: Axes | None = None,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    show_hist: bool = True,
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> Axes | tuple[Axes, Axes]:
    """
    Draw a date-calibrated tree whose branches are colored by a continuous metadata key.

    Colors come from `_init_colored_tree_continuous`, which also adds either
    an inset histogram (`show_hist=True`) or a colorbar. The tree is then
    drawn with `draw_dated_tree`.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree to draw.
    calibration_nodes : tuple[CalibrationNode, CalibrationNode]
        Two calibration nodes defining the mapping from depth to date.
    color_by : str
        The metadata key to color branches by.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, uses the current Axes.
    default_color : Color, optional
        The color to use for nodes without the specified metadata.
    colormap : str | Colormap, optional
        The colormap to use for coloring continuous values. Defaults to 'viridis'.
    vmin : float | None, optional
        The minimum value for normalization. If None, uses the minimum of the data.
    vmax : float | None, optional
        The maximum value for normalization. If None, uses the maximum of the data.
    show_hist : bool, optional
        Whether to display a histogram of the continuous values.
    hist_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to pass to the histogram.
    hist_axes_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to define the histogram Axes.

    Returns
    -------
    Axes | tuple[Axes, Axes]
        The tree Axes alone when `show_hist` is False, otherwise the pair
        (tree Axes, histogram Axes).
    """
    if not show_hist:
        tree_ax, node_colors = _init_colored_tree_continuous(
            tree=tree,
            color_by=color_by,
            ax=ax,
            default_color=default_color,
            colormap=colormap,
            vmin=vmin,
            vmax=vmax,
            show_hist=False,
            hist_kwargs=hist_kwargs,
            hist_axes_kwargs=hist_axes_kwargs,
        )
        return draw_dated_tree(
            tree=tree,
            calibration_nodes=calibration_nodes,
            ax=tree_ax,
            colors=node_colors,
        )

    tree_ax, node_colors, hist_ax = _init_colored_tree_continuous(
        tree=tree,
        color_by=color_by,
        ax=ax,
        default_color=default_color,
        colormap=colormap,
        vmin=vmin,
        vmax=vmax,
        show_hist=True,
        hist_kwargs=hist_kwargs,
        hist_axes_kwargs=hist_axes_kwargs,
    )
    tree_ax = draw_dated_tree(
        tree=tree,
        calibration_nodes=calibration_nodes,
        ax=tree_ax,
        colors=node_colors,
    )
    return tree_ax, hist_ax
@@ -0,0 +1,3 @@
1
+ from phylogenie.io.fasta import dump_fasta, load_fasta
2
+
3
+ __all__ = ["load_fasta", "dump_fasta"]
@@ -1,3 +1,4 @@
1
+ from datetime import date
1
2
  from pathlib import Path
2
3
  from typing import Callable
3
4
 
@@ -5,7 +6,8 @@ from phylogenie.msa import MSA, Sequence
5
6
 
6
7
 
7
8
  def load_fasta(
8
- fasta_file: str | Path, extract_time_from_id: Callable[[str], float] | None = None
9
+ fasta_file: str | Path,
10
+ extract_time_from_id: Callable[[str], float | date] | None = None,
9
11
  ) -> MSA:
10
12
  sequences: list[Sequence] = []
11
13
  with open(fasta_file, "r") as f:
@@ -17,10 +19,22 @@ def load_fasta(
17
19
  if extract_time_from_id is not None:
18
20
  time = extract_time_from_id(id)
19
21
  elif "|" in id:
22
+ last_metadata = id.split("|")[-1]
20
23
  try:
21
- time = float(id.split("|")[-1])
24
+ time = float(last_metadata)
22
25
  except ValueError:
23
- pass
26
+ try:
27
+ time = date.fromisoformat(last_metadata)
28
+ except ValueError:
29
+ pass
24
30
  chars = next(f).strip()
25
31
  sequences.append(Sequence(id, chars, time))
26
32
  return MSA(sequences)
33
+
34
+
35
def dump_fasta(msa: MSA | list[Sequence], fasta_file: str | Path) -> None:
    """
    Write sequences to a FASTA file.

    Each sequence produces a `>id` header line followed by a single line with
    its characters. The file is overwritten if it already exists.

    Parameters
    ----------
    msa : MSA | list[Sequence]
        An alignment (its `sequences` are written) or a plain list of sequences.
    fasta_file : str | Path
        Destination path.
    """
    with open(fasta_file, "w") as handle:
        records = msa.sequences if isinstance(msa, MSA) else msa
        handle.writelines(
            line
            for record in records
            for line in (f">{record.id}\n", f"{record.chars}\n")
        )
@@ -1,39 +1,31 @@
1
+ from collections.abc import Mapping
1
2
  from types import MappingProxyType
2
- from typing import Any, Mapping, Optional
3
+ from typing import Any
3
4
 
4
5
 
5
6
  class MetadataMixin:
6
- """A mixin that provides metadata management with dictionary-like access."""
7
-
8
7
  def __init__(self) -> None:
9
8
  self._metadata: dict[str, Any] = {}
10
9
 
11
10
  @property
12
11
  def metadata(self) -> Mapping[str, Any]:
13
- """Return a read-only view of all metadata."""
14
12
  return MappingProxyType(self._metadata)
15
13
 
16
14
  def set(self, key: str, value: Any) -> None:
17
- """Set or update a metadata value."""
18
15
  self._metadata[key] = value
19
16
 
20
17
  def update(self, metadata: Mapping[str, Any]) -> None:
21
- """Bulk update metadata values."""
22
18
  self._metadata.update(metadata)
23
19
 
24
- def get(self, key: str, default: Optional[Any] = None) -> Any:
25
- """Get a metadata value, returning `default` if not found."""
20
+ def get(self, key: str, default: Any = None) -> Any:
26
21
  return self._metadata.get(key, default)
27
22
 
28
23
  def delete(self, key: str) -> None:
29
- """Delete a metadata if it exists, else do nothing."""
30
24
  self._metadata.pop(key, None)
31
25
 
32
26
  def clear(self) -> None:
33
- """Remove all metadata."""
34
27
  self._metadata.clear()
35
28
 
36
- # Dict-like behavior
37
29
  def __getitem__(self, key: str) -> Any:
38
30
  return self._metadata[key]
39
31
 
@@ -1,5 +1,6 @@
1
1
  from collections.abc import Iterator
2
2
  from dataclasses import dataclass
3
+ from datetime import date
3
4
 
4
5
  import numpy as np
5
6
 
@@ -8,7 +9,7 @@ import numpy as np
8
9
  class Sequence:
9
10
  id: str
10
11
  chars: str
11
- time: float | None = None
12
+ time: float | date | None = None
12
13
 
13
14
 
14
15
  class MSA:
@@ -25,8 +26,8 @@ class MSA:
25
26
  return [sequence.id for sequence in self.sequences]
26
27
 
27
28
  @property
28
- def times(self) -> list[float]:
29
- times: list[float] = []
29
+ def times(self) -> list[float | date]:
30
+ times: list[float | date] = []
30
31
  for sequence in self:
31
32
  if sequence.time is None:
32
33
  raise ValueError(f"Time is not set for sequence {sequence.id}.")
@@ -27,11 +27,14 @@ from phylogenie.treesimulator.utils import (
27
27
  compute_sackin_index,
28
28
  get_distance,
29
29
  get_mrca,
30
+ get_node_ages,
30
31
  get_node_depth_levels,
31
32
  get_node_depths,
32
33
  get_node_height_levels,
33
34
  get_node_heights,
34
35
  get_node_leaf_counts,
36
+ get_node_times,
37
+ get_path,
35
38
  )
36
39
 
37
40
  __all__ = [
@@ -65,9 +68,12 @@ __all__ = [
65
68
  "compute_sackin_index",
66
69
  "get_distance",
67
70
  "get_mrca",
71
+ "get_node_ages",
68
72
  "get_node_depth_levels",
69
73
  "get_node_depths",
70
74
  "get_node_height_levels",
71
75
  "get_node_heights",
72
76
  "get_node_leaf_counts",
77
+ "get_node_times",
78
+ "get_path",
73
79
  ]
@@ -5,11 +5,13 @@ from phylogenie.treesimulator.events.mutations import get_mutation_id
5
5
  from phylogenie.treesimulator.model import get_node_state
6
6
  from phylogenie.treesimulator.tree import Tree
7
7
  from phylogenie.treesimulator.utils import (
8
+ get_node_ages,
8
9
  get_node_depth_levels,
9
10
  get_node_depths,
10
11
  get_node_height_levels,
11
12
  get_node_heights,
12
13
  get_node_leaf_counts,
14
+ get_node_times,
13
15
  )
14
16
 
15
17
 
@@ -22,6 +24,7 @@ def _get_mutations(tree: Tree) -> dict[Tree, int]:
22
24
 
23
25
 
24
26
  class Feature(str, Enum):
27
+ AGE = "age"
25
28
  DEPTH = "depth"
26
29
  DEPTH_LEVEL = "depth_level"
27
30
  HEIGHT = "height"
@@ -29,9 +32,11 @@ class Feature(str, Enum):
29
32
  MUTATION = "mutation"
30
33
  N_LEAVES = "n_leaves"
31
34
  STATE = "state"
35
+ TIME = "time"
32
36
 
33
37
 
34
38
  FEATURES_EXTRACTORS = {
39
+ Feature.AGE: get_node_ages,
35
40
  Feature.DEPTH: get_node_depths,
36
41
  Feature.DEPTH_LEVEL: get_node_depth_levels,
37
42
  Feature.HEIGHT: get_node_heights,
@@ -39,6 +44,7 @@ FEATURES_EXTRACTORS = {
39
44
  Feature.MUTATION: _get_mutations,
40
45
  Feature.N_LEAVES: get_node_leaf_counts,
41
46
  Feature.STATE: _get_states,
47
+ Feature.TIME: get_node_times,
42
48
  }
43
49
 
44
50
 
@@ -15,7 +15,7 @@ def _parse_translate_block(lines: Iterator[str]) -> dict[str, str]:
15
15
  if ";" in line:
16
16
  return translations
17
17
  else:
18
- raise ValueError(f"Invalid translate line. Expected '<num> <name>'.")
18
+ raise ValueError("Invalid translate line. Expected '<num> <name>'.")
19
19
  translations[match.group(1)] = match.group(2)
20
20
  raise ValueError("Translate block not terminated with ';'.")
21
21
 
@@ -33,7 +33,7 @@ def _parse_trees_block(lines: Iterator[str]) -> dict[str, Tree]:
33
33
  match = re.match(r"^TREE\s*\*?\s+(\S+)\s*=\s*(.+)$", line, re.IGNORECASE)
34
34
  if match is None:
35
35
  raise ValueError(
36
- f"Invalid tree line. Expected 'TREE <name> = <newick>'."
36
+ "Invalid tree line. Expected 'TREE <name> = <newick>'."
37
37
  )
38
38
  name = match.group(1)
39
39
  if name in trees:
@@ -153,14 +153,27 @@ class Tree(MetadataMixin):
153
153
  child.branch_length_or_raise() + child.height for child in self.children
154
154
  )
155
155
 
156
+ @property
157
def time(self) -> float:
    """Return this node's time.

    This is an alias for the ``depth`` property. Presumably that is the
    distance from the root (TODO confirm against ``depth``, which is defined
    outside this view).
    """
    node_depth = self.depth
    return node_depth
159
+
160
+ @property
161
def age(self) -> float:
    """Return this node's age.

    The root's age is its ``height``. Every other node's age is its parent's
    age minus its own branch length. Ancestors are collected iteratively, and
    the branch lengths are then subtracted from the root downward. This keeps
    the same order of floating-point operations as a parent-first recursion.

    Raises:
        Whatever ``branch_length_or_raise`` raises for a node on the path to
        the root that has no branch length.
    """
    lineage: list["Tree"] = []
    cursor = self
    while cursor.parent is not None:
        lineage.append(cursor)
        cursor = cursor.parent
    result = cursor.height
    for node in reversed(lineage):
        result = result - node.branch_length_or_raise()
    return result
165
+
156
166
  # -------------
157
167
  # Miscellaneous
158
168
  # -------------
159
169
  # Other useful miscellaneous methods.
160
170
 
161
171
  def ladderize(self, key: Callable[["Tree"], Any] | None = None) -> None:
172
+ def _default_key(node: Tree) -> int:
173
+ return node.n_leaves
174
+
162
175
  if key is None:
163
- key = lambda node: node.n_leaves
176
+ key = _default_key
164
177
  self._children.sort(key=key)
165
178
  for child in self.children:
166
179
  child.ladderize(key)
@@ -169,7 +182,7 @@ class Tree(MetadataMixin):
169
182
  for node in self:
170
183
  if node.name == name:
171
184
  return node
172
- raise ValueError(f"Node with name {name} not found.")
185
+ raise ValueError(f"Node {name} not found.")
173
186
 
174
187
  def copy(self):
175
188
  new_tree = Tree(self.name, self.branch_length)
@@ -51,6 +51,17 @@ def get_node_heights(tree: Tree) -> dict[Tree, float]:
51
51
  return heights
52
52
 
53
53
 
54
def get_node_times(tree: Tree) -> dict[Tree, float]:
    """Map every node of ``tree`` to its time.

    A node's time is the same value ``get_node_depths`` reports for it. This
    function delegates to that helper and returns its result unchanged.
    """
    node_times = get_node_depths(tree)
    return node_times
56
+
57
+
58
def get_node_ages(tree: Tree) -> dict[Tree, float]:
    """Map every node of ``tree`` to its age.

    The root's age is its ``height``. Every other node's age is its parent's
    age minus the node's own branch length, which matches ``Tree.age``.

    This relies on ``tree.iter_descendants()`` yielding each node after its
    parent: a node's age is read from the parent's already-computed entry,
    which would otherwise raise ``KeyError``.

    Raises:
        Whatever ``branch_length_or_raise`` raises for a descendant without a
        branch length. Previously a raw ``None`` branch length could surface
        as an opaque ``TypeError`` from the subtraction.
    """
    ages: dict[Tree, float] = {tree: tree.height}
    for node in tree.iter_descendants():
        # Descendants always have a parent; the ignore only silences the
        # Optional-typed key lookup.
        parent_age = ages[node.parent]  # pyright: ignore
        ages[node] = parent_age - node.branch_length_or_raise()
    return ages
63
+
64
+
54
65
  def get_mrca(node1: Tree, node2: Tree) -> Tree:
55
66
  node1_ancestors = set(node1.iter_upward())
56
67
  for node2_ancestor in node2.iter_upward():
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: phylogenie
3
- Version: 3.1.5
3
+ Version: 3.1.9
4
4
  Summary: Generate phylogenetic datasets with minimal setup effort
5
5
  Requires-Python: >=3.10
6
6
  Description-Content-Type: text/markdown
@@ -1,152 +0,0 @@
1
- from enum import Enum
2
- from typing import Any, Callable
3
-
4
- import matplotlib.colors as mcolors
5
- import matplotlib.patches as mpatches
6
- import matplotlib.pyplot as plt
7
- from matplotlib.axes import Axes
8
- from matplotlib.colors import Colormap
9
- from mpl_toolkits.axes_grid1.inset_locator import inset_axes # pyright: ignore
10
-
11
- from phylogenie.treesimulator import Tree, get_node_depth_levels, get_node_depths
12
-
13
-
14
- class Coloring(str, Enum):
15
- DISCRETE = "discrete"
16
- CONTINUOUS = "continuous"
17
-
18
-
19
- Color = str | tuple[float, float, float] | tuple[float, float, float, float]
20
-
21
-
22
- def draw_colored_tree(
23
- tree: Tree, ax: Axes | None = None, colors: Color | dict[Tree, Color] = "black"
24
- ) -> Axes:
25
- if ax is None:
26
- ax = plt.gca()
27
-
28
- if not isinstance(colors, dict):
29
- colors = {node: colors for node in tree}
30
-
31
- xs = (
32
- get_node_depth_levels(tree)
33
- if any(node.branch_length is None for node in tree.iter_descendants())
34
- else get_node_depths(tree)
35
- )
36
- ys: dict[Tree, float] = {node: i for i, node in enumerate(tree.get_leaves())}
37
- for node in tree.postorder_traversal():
38
- if node.is_internal():
39
- ys[node] = sum(ys[child] for child in node.children) / len(node.children)
40
-
41
- if tree.branch_length is not None:
42
- ax.hlines(y=ys[tree], xmin=0, xmax=xs[tree], color=colors[tree]) # pyright: ignore
43
- for node in tree:
44
- x1, y1 = xs[node], ys[node]
45
- for child in node.children:
46
- x2, y2 = xs[child], ys[child]
47
- ax.hlines(y=y2, xmin=x1, xmax=x2, color=colors[child]) # pyright: ignore
48
- ax.vlines(x=x1, ymin=y1, ymax=y2, color=colors[child]) # pyright: ignore
49
-
50
- ax.set_yticks([]) # pyright: ignore
51
- return ax
52
-
53
-
54
- def draw_tree(
55
- tree: Tree,
56
- ax: Axes | None = None,
57
- color_by: str | dict[str, Any] | None = None,
58
- coloring: str | Coloring | None = None,
59
- default_color: Color = "black",
60
- colormap: str | Colormap | None = None,
61
- vmin: float | None = None,
62
- vmax: float | None = None,
63
- show_legend: bool = True,
64
- labels: dict[Any, Any] | None = None,
65
- legend_kwargs: dict[str, Any] | None = None,
66
- show_hist: bool = True,
67
- hist_kwargs: dict[str, Any] | None = None,
68
- hist_axes_kwargs: dict[str, Any] | None = None,
69
- ) -> Axes | tuple[Axes, Axes]:
70
- if ax is None:
71
- ax = plt.gca()
72
-
73
- if color_by is None:
74
- return draw_colored_tree(tree, ax, colors=default_color)
75
-
76
- if isinstance(color_by, str):
77
- features = {node: node[color_by] for node in tree if color_by in node.metadata}
78
- else:
79
- features = {node: color_by[node.name] for node in tree if node.name in color_by}
80
- values = list(features.values())
81
-
82
- if coloring is None:
83
- coloring = (
84
- Coloring.CONTINUOUS
85
- if any(isinstance(f, float) for f in values)
86
- else Coloring.DISCRETE
87
- )
88
- if colormap is None:
89
- colormap = "tab20" if coloring == Coloring.DISCRETE else "viridis"
90
- if isinstance(colormap, str):
91
- colormap = plt.get_cmap(colormap)
92
-
93
- def _get_colors(feature_map: Callable[[Any], Color]) -> dict[Tree, Color]:
94
- return {
95
- node: feature_map(features[node]) if node in features else default_color
96
- for node in tree
97
- }
98
-
99
- if coloring == Coloring.DISCRETE:
100
- if any(isinstance(f, float) for f in values):
101
- raise ValueError(
102
- "Discrete coloring selected but feature values are not all categorical."
103
- )
104
- feature_colors = {
105
- f: mcolors.to_hex(colormap(i)) for i, f in enumerate(set(values))
106
- }
107
- colors = _get_colors(lambda f: feature_colors[f])
108
-
109
- if show_legend:
110
- legend_handles = [
111
- mpatches.Patch(
112
- color=feature_colors[f],
113
- label=str(f) if labels is None else labels[f],
114
- )
115
- for f in feature_colors
116
- ]
117
- if any(color_by not in node.metadata for node in tree):
118
- legend_handles.append(mpatches.Patch(color=default_color, label="NA"))
119
- if legend_kwargs is None:
120
- legend_kwargs = {}
121
- ax.legend(handles=legend_handles, **legend_kwargs) # pyright: ignore
122
-
123
- return draw_colored_tree(tree, ax, colors)
124
-
125
- if coloring == Coloring.CONTINUOUS:
126
- vmin = min(values) if vmin is None else vmin
127
- vmax = max(values) if vmax is None else vmax
128
- norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
129
- colors = _get_colors(lambda f: colormap(norm(float(f))))
130
-
131
- if show_hist:
132
- default_hist_axes_kwargs = {"width": "25%", "height": "25%"}
133
- if hist_axes_kwargs is not None:
134
- default_hist_axes_kwargs.update(hist_axes_kwargs)
135
- hist_ax = inset_axes(ax, **default_hist_axes_kwargs) # pyright: ignore
136
-
137
- hist_kwargs = {} if hist_kwargs is None else hist_kwargs
138
- _, bins, patches = hist_ax.hist(values, **hist_kwargs) # pyright: ignore
139
-
140
- for patch, b0, b1 in zip(patches, bins[:-1], bins[1:]): # pyright: ignore
141
- midpoint = (b0 + b1) / 2 # pyright: ignore
142
- patch.set_facecolor(colormap(norm(midpoint))) # pyright: ignore
143
- return draw_colored_tree(tree, ax, colors), hist_ax # pyright: ignore
144
-
145
- else:
146
- sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)
147
- ax.get_figure().colorbar(sm, ax=ax) # pyright: ignore
148
- return draw_colored_tree(tree, ax, colors)
149
-
150
- raise ValueError(
151
- f"Unknown coloring method: {coloring}. Choices are {list(Coloring)}."
152
- )
@@ -1,3 +0,0 @@
1
- from phylogenie.io.fasta import load_fasta
2
-
3
- __all__ = ["load_fasta"]
File without changes
File without changes
File without changes