phylogenie 2.1.4__py3-none-any.whl → 3.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. phylogenie/__init__.py +60 -14
  2. phylogenie/draw.py +690 -0
  3. phylogenie/generators/alisim.py +12 -12
  4. phylogenie/generators/configs.py +26 -4
  5. phylogenie/generators/dataset.py +3 -3
  6. phylogenie/generators/factories.py +38 -12
  7. phylogenie/generators/trees.py +48 -47
  8. phylogenie/io/__init__.py +3 -0
  9. phylogenie/io/fasta.py +34 -0
  10. phylogenie/main.py +27 -10
  11. phylogenie/mixins.py +33 -0
  12. phylogenie/skyline/matrix.py +11 -7
  13. phylogenie/skyline/parameter.py +12 -4
  14. phylogenie/skyline/vector.py +12 -6
  15. phylogenie/treesimulator/__init__.py +36 -3
  16. phylogenie/treesimulator/events/__init__.py +5 -5
  17. phylogenie/treesimulator/events/base.py +39 -0
  18. phylogenie/treesimulator/events/contact_tracing.py +38 -23
  19. phylogenie/treesimulator/events/core.py +21 -12
  20. phylogenie/treesimulator/events/mutations.py +46 -46
  21. phylogenie/treesimulator/features.py +49 -0
  22. phylogenie/treesimulator/gillespie.py +59 -55
  23. phylogenie/treesimulator/io/__init__.py +4 -0
  24. phylogenie/treesimulator/io/newick.py +104 -0
  25. phylogenie/treesimulator/io/nexus.py +50 -0
  26. phylogenie/treesimulator/model.py +25 -49
  27. phylogenie/treesimulator/tree.py +196 -0
  28. phylogenie/treesimulator/utils.py +108 -0
  29. phylogenie/typings.py +3 -3
  30. {phylogenie-2.1.4.dist-info → phylogenie-3.1.7.dist-info}/METADATA +13 -15
  31. phylogenie-3.1.7.dist-info/RECORD +41 -0
  32. {phylogenie-2.1.4.dist-info → phylogenie-3.1.7.dist-info}/WHEEL +2 -1
  33. phylogenie-3.1.7.dist-info/entry_points.txt +2 -0
  34. phylogenie-3.1.7.dist-info/top_level.txt +1 -0
  35. phylogenie/io.py +0 -107
  36. phylogenie/tree.py +0 -92
  37. phylogenie/utils.py +0 -17
  38. phylogenie-2.1.4.dist-info/RECORD +0 -32
  39. phylogenie-2.1.4.dist-info/entry_points.txt +0 -3
  40. {phylogenie-2.1.4.dist-info → phylogenie-3.1.7.dist-info/licenses}/LICENSE.txt +0 -0
phylogenie/draw.py ADDED
@@ -0,0 +1,690 @@
1
+ from dataclasses import dataclass
2
+ from datetime import datetime
3
+ from typing import Any, Literal, overload
4
+
5
+ import matplotlib.colors as mcolors
6
+ import matplotlib.dates as mdates
7
+ import matplotlib.patches as mpatches
8
+ import matplotlib.pyplot as plt
9
+ from matplotlib.axes import Axes
10
+ from matplotlib.colors import Colormap
11
+ from mpl_toolkits.axes_grid1.inset_locator import inset_axes # pyright: ignore
12
+
13
+ from phylogenie.treesimulator import Tree, get_node_depth_levels, get_node_depths
14
+
15
+
16
@dataclass
class CalibrationNode:
    """
    Anchor point tying a tree node to a calendar date.

    The dated drawing functions in this module take a pair of these and
    linearly interpolate between them (see `_depth_to_date`) to convert
    node depths into dates.
    """

    # Tree node used as the anchor; `_depth_to_date` reads its `depth` attribute.
    node: Tree
    # Calendar date assigned to that node.
    date: datetime
20
+
21
+
22
# Color specification accepted by the drawing functions: either a string
# (e.g. the "black" default used below) or a tuple of 3 or 4 floats
# (presumably RGB / RGBA components -- confirm against matplotlib usage).
Color = str | tuple[float, float, float] | tuple[float, float, float, float]
23
+
24
+
25
def draw_tree(
    tree: Tree,
    ax: Axes | None = None,
    colors: Color | dict[Tree, Color] = "black",
    backward_time: bool = False,
) -> Axes:
    """
    Plot a phylogenetic tree as a rectangular dendrogram.

    Parameters
    ----------
    tree : Tree
        Root of the tree to plot.
    ax : Axes | None, optional
        Target matplotlib Axes; the current Axes is used when omitted.
    colors : Color | dict[Tree, Color], optional
        Either one color applied to every branch, or a per-node mapping giving
        the color of the branch that leads to each node.
    backward_time : bool, optional
        When True, x coordinates are measured back from the deepest node and
        the x-axis is inverted.

    Returns
    -------
    Axes
        The Axes the tree was drawn on.
    """
    target = plt.gca() if ax is None else ax

    if isinstance(colors, dict):
        palette = colors
    else:
        palette = {node: colors for node in tree}

    # Use depth levels as soon as one node yielded by iter_descendants()
    # has no branch length; otherwise use the actual depths.
    missing_lengths = any(
        node.branch_length is None for node in tree.iter_descendants()
    )
    x_of = get_node_depth_levels(tree) if missing_lengths else get_node_depths(tree)
    if backward_time:
        deepest = max(x_of.values())
        x_of = {node: deepest - x for node, x in x_of.items()}

    # Leaves occupy consecutive integer rows; every internal node is placed
    # at the mean row of its children (children are visited first).
    y_of: dict[Tree, float] = {}
    for row, leaf in enumerate(tree.get_leaves()):
        y_of[leaf] = row
    for node in tree.postorder_traversal():
        if not node.is_internal():
            continue
        offspring = node.children
        y_of[node] = sum(y_of[kid] for kid in offspring) / len(offspring)

    # Stem leading to the root, drawn only when the root has a branch length.
    stem = tree.branch_length
    if stem is not None:
        if backward_time:
            stem_start = x_of[tree] + stem
        else:
            stem_start = 0
        target.hlines(  # pyright: ignore
            y=y_of[tree], xmin=stem_start, xmax=x_of[tree], color=palette[tree]
        )

    # One horizontal segment (the branch) and one vertical connector per child.
    for parent in tree:
        parent_x, parent_y = x_of[parent], y_of[parent]
        for kid in parent.children:
            kid_x, kid_y = x_of[kid], y_of[kid]
            shade = palette[kid]
            target.hlines(y=kid_y, xmin=parent_x, xmax=kid_x, color=shade)  # pyright: ignore
            target.vlines(x=parent_x, ymin=parent_y, ymax=kid_y, color=shade)  # pyright: ignore

    if backward_time:
        target.invert_xaxis()

    target.set_yticks([])  # pyright: ignore
    return target
85
+
86
+
87
def _depth_to_date(
    depth: float, calibration_nodes: tuple[CalibrationNode, CalibrationNode]
) -> datetime:
    """
    Convert a depth value to a date using linear interpolation between two calibration nodes.

    Parameters
    ----------
    depth : float
        The depth value to convert.
    calibration_nodes : tuple[CalibrationNode, CalibrationNode]
        Two calibration nodes defining the mapping from depth to date.
        Their nodes must sit at different depths.

    Returns
    -------
    datetime
        The interpolated (or extrapolated) date corresponding to the given depth.

    Raises
    ------
    ZeroDivisionError
        If both calibration nodes have the same depth, in which case the
        depth-to-date mapping is undefined.
    """
    first, second = calibration_nodes
    first_depth, second_depth = first.node.depth, second.node.depth
    depth_span = second_depth - first_depth
    if depth_span == 0:
        # Same exception type the bare division would raise, but with a
        # message that points at the actual cause.
        raise ZeroDivisionError(
            "Cannot map depths to dates: both calibration nodes have the same "
            f"depth ({first_depth})."
        )
    return first.date + (depth - first_depth) * (second.date - first.date) / depth_span
109
+
110
+
111
def draw_dated_tree(
    tree: Tree,
    calibration_nodes: tuple[CalibrationNode, CalibrationNode],
    ax: Axes | None = None,
    colors: Color | dict[Tree, Color] = "black",
) -> Axes:
    """
    Plot a phylogenetic tree on a calendar-date x-axis.

    Parameters
    ----------
    tree : Tree
        Root of the tree to plot.
    calibration_nodes : tuple[CalibrationNode, CalibrationNode]
        Pair of anchors used to translate node depths into dates.
    ax : Axes | None, optional
        Target matplotlib Axes; the current Axes is used when omitted.
    colors : Color | dict[Tree, Color], optional
        Either one color applied to every branch, or a per-node mapping giving
        the color of the branch that leads to each node.

    Returns
    -------
    Axes
        The Axes the dated tree was drawn on.
    """
    target = plt.gca() if ax is None else ax
    palette = colors if isinstance(colors, dict) else {node: colors for node in tree}

    # Date of every node, obtained from its depth.
    date_of = {}
    for node, depth in get_node_depths(tree).items():
        date_of[node] = _depth_to_date(
            depth=depth, calibration_nodes=calibration_nodes
        )

    # Leaves occupy consecutive integer rows; every internal node is placed
    # at the mean row of its children (children are visited first).
    row_of: dict[Tree, float] = {}
    for row, leaf in enumerate(tree.get_leaves()):
        row_of[leaf] = row
    for node in tree.postorder_traversal():
        if not node.is_internal():
            continue
        offspring = node.children
        row_of[node] = sum(row_of[kid] for kid in offspring) / len(offspring)

    # Stem from the date of depth 0 to the root, drawn only when the root
    # has a branch length.
    if tree.branch_length is not None:
        origin = _depth_to_date(depth=0, calibration_nodes=calibration_nodes)
        target.hlines(  # pyright: ignore
            y=row_of[tree],
            xmin=mdates.date2num(origin),  # pyright: ignore
            xmax=mdates.date2num(date_of[tree]),  # pyright: ignore
            color=palette[tree],
        )

    # One horizontal segment (the branch) and one vertical connector per child.
    for parent in tree:
        parent_date, parent_row = date_of[parent], row_of[parent]
        for kid in parent.children:
            kid_date, kid_row = date_of[kid], row_of[kid]
            shade = palette[kid]
            target.hlines(  # pyright: ignore
                y=kid_row,
                xmin=mdates.date2num(parent_date),  # pyright: ignore
                xmax=mdates.date2num(kid_date),  # pyright: ignore
                color=shade,
            )
            target.vlines(  # pyright: ignore
                x=mdates.date2num(parent_date),  # pyright: ignore
                ymin=parent_row,
                ymax=kid_row,
                color=shade,
            )

    x_axis = target.xaxis
    x_axis.set_major_locator(mdates.AutoDateLocator())
    x_axis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d"))
    target.tick_params(axis="x", labelrotation=45)  # pyright: ignore

    target.set_yticks([])  # pyright: ignore
    return target
178
+
179
+
180
def _init_colored_tree_categorical(
    tree: Tree,
    color_by: str,
    ax: Axes | None = None,
    default_color: Color = "black",
    colormap: str | Colormap = "tab20",
    show_legend: bool = True,
    labels: dict[Any, str] | None = None,
    legend_kwargs: dict[str, Any] | None = None,
) -> tuple[Axes, dict[Tree, Color]]:
    """
    Initialize colors for drawing a tree based on categorical metadata.

    Categories are numbered in the order in which they are first met while
    iterating over the tree, so a given tree always gets the same
    category-to-color assignment and the same legend order.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree.
    color_by : str
        The metadata key to color branches by.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, uses the current Axes.
    default_color : Color, optional
        The color to use for nodes without the specified metadata.
    colormap : str | Colormap, optional
        The colormap to use for coloring categories. Defaults to 'tab20'.
        When there are more categories than colormap entries, colors are
        reused cyclically.
    show_legend : bool, optional
        Whether to display a legend for the categories.
    labels : dict[Any, str] | None, optional
        A mapping from category values to labels for the legend. Categories
        missing from the mapping are labelled with their string representation.
    legend_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to pass to the legend.

    Returns
    -------
    tuple[Axes, dict[Tree, Color]]
        The Axes and a dictionary mapping each node to its assigned color.
    """
    if ax is None:
        ax = plt.gca()

    if isinstance(colormap, str):
        colormap = plt.get_cmap(colormap)

    features = {node: node[color_by] for node in tree if color_by in node.metadata}

    # dict.fromkeys de-duplicates while keeping first-seen order; a set would
    # make the category -> color assignment depend on hash ordering.
    categories = list(dict.fromkeys(features.values()))
    # Integer lookups past the last colormap entry all return the same color,
    # so wrap the index to keep cycling through the available entries.
    feature_colors = {
        category: mcolors.to_hex(colormap(index % colormap.N))
        for index, category in enumerate(categories)
    }
    colors = {
        node: feature_colors[features[node]] if node in features else default_color
        for node in tree
    }

    if show_legend:
        legend_handles = [
            mpatches.Patch(
                color=color,
                label=str(category)
                if labels is None
                else labels.get(category, str(category)),
            )
            for category, color in feature_colors.items()
        ]
        # Extra entry for nodes that fell back to the default color.
        if any(color_by not in node.metadata for node in tree):
            legend_handles.append(mpatches.Patch(color=default_color, label="NA"))
        if legend_kwargs is None:
            legend_kwargs = {}
        ax.legend(handles=legend_handles, **legend_kwargs)  # pyright: ignore

    return ax, colors
247
+
248
+
249
def draw_colored_tree_categorical(
    tree: Tree,
    color_by: str,
    ax: Axes | None = None,
    backward_time: bool = False,
    default_color: Color = "black",
    colormap: str | Colormap = "tab20",
    show_legend: bool = True,
    labels: dict[Any, str] | None = None,
    legend_kwargs: dict[str, Any] | None = None,
) -> Axes:
    """
    Draw a phylogenetic tree with branches colored based on categorical metadata.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree to draw.
    color_by : str
        The metadata key to color branches by.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, uses the current Axes.
    backward_time : bool, optional
        If True, the x-axis is inverted to represent time going backward.
    default_color : Color, optional
        The color to use for nodes without the specified metadata.
    colormap : str | Colormap, optional
        The colormap to use for coloring categories. Defaults to 'tab20'.
    show_legend : bool, optional
        Whether to display a legend for the categories.
    labels : dict[Any, str] | None, optional
        A mapping from category values to labels for the legend.
    legend_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to pass to the legend.

    Returns
    -------
    Axes
        The Axes with the drawn colored tree.
    """
    # Resolve the Axes, the per-node colors and (optionally) the legend first,
    # then delegate the actual drawing to draw_tree.
    ax, colors = _init_colored_tree_categorical(
        tree=tree,
        color_by=color_by,
        ax=ax,
        default_color=default_color,
        colormap=colormap,
        show_legend=show_legend,
        labels=labels,
        legend_kwargs=legend_kwargs,
    )
    return draw_tree(tree=tree, ax=ax, colors=colors, backward_time=backward_time)
300
+
301
+
302
def draw_colored_dated_tree_categorical(
    tree: Tree,
    calibration_nodes: tuple[CalibrationNode, CalibrationNode],
    color_by: str,
    ax: Axes | None = None,
    default_color: Color = "black",
    colormap: str | Colormap = "tab20",
    show_legend: bool = True,
    labels: dict[Any, str] | None = None,
    legend_kwargs: dict[str, Any] | None = None,
) -> Axes:
    """
    Plot a date-calibrated tree whose branches are colored by a categorical metadata key.

    Parameters
    ----------
    tree : Tree
        Root of the tree to plot.
    calibration_nodes : tuple[CalibrationNode, CalibrationNode]
        Pair of anchors used to translate node depths into dates.
    color_by : str
        Metadata key whose value selects each branch's color.
    ax : Axes | None, optional
        Target matplotlib Axes; the current Axes is used when omitted.
    default_color : Color, optional
        Color given to nodes that do not carry `color_by`.
    colormap : str | Colormap, optional
        Colormap supplying the category colors. Defaults to 'tab20'.
    show_legend : bool, optional
        Whether a legend listing the categories is added.
    labels : dict[Any, str] | None, optional
        Legend label for each category value.
    legend_kwargs : dict[str, Any] | None, optional
        Extra keyword arguments forwarded to the legend.

    Returns
    -------
    Axes
        The Axes the colored dated tree was drawn on.
    """
    # Step 1: resolve the Axes and compute one color per node (plus legend).
    coloring = _init_colored_tree_categorical(
        tree=tree,
        color_by=color_by,
        ax=ax,
        default_color=default_color,
        colormap=colormap,
        show_legend=show_legend,
        labels=labels,
        legend_kwargs=legend_kwargs,
    )
    target, node_colors = coloring

    # Step 2: draw the tree itself on the date axis.
    drawn = draw_dated_tree(
        tree=tree,
        calibration_nodes=calibration_nodes,
        ax=target,
        colors=node_colors,
    )
    return drawn
355
+
356
+
357
@overload
def _init_colored_tree_continuous(
    tree: Tree,
    color_by: str,
    ax: Axes | None = ...,
    default_color: Color = ...,
    colormap: str | Colormap = ...,
    vmin: float | None = ...,
    vmax: float | None = ...,
    *,
    show_hist: Literal[False],
    hist_kwargs: dict[str, Any] | None = ...,
    hist_axes_kwargs: dict[str, Any] | None = ...,
) -> tuple[Axes, dict[Tree, Color]]: ...
# `show_hist` defaults to True at runtime, so this overload must also match
# calls that omit it.
@overload
def _init_colored_tree_continuous(
    tree: Tree,
    color_by: str,
    ax: Axes | None = ...,
    default_color: Color = ...,
    colormap: str | Colormap = ...,
    vmin: float | None = ...,
    vmax: float | None = ...,
    *,
    show_hist: Literal[True] = ...,
    hist_kwargs: dict[str, Any] | None = ...,
    hist_axes_kwargs: dict[str, Any] | None = ...,
) -> tuple[Axes, dict[Tree, Color], Axes]: ...
def _init_colored_tree_continuous(
    tree: Tree,
    color_by: str,
    ax: Axes | None = None,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    *,
    show_hist: bool = True,
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> tuple[Axes, dict[Tree, Color]] | tuple[Axes, dict[Tree, Color], Axes]:
    """
    Initialize colors for drawing a tree based on continuous metadata.

    When `show_hist` is True an inset histogram of the values is added to the
    Axes, with bars colored by the colormap; otherwise a colorbar is attached
    to the figure instead.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree.
    color_by : str
        The metadata key to color branches by.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, uses the current Axes.
    default_color : Color, optional
        The color to use for nodes without the specified metadata.
    colormap : str | Colormap, optional
        The colormap to use for coloring continuous values. Defaults to 'viridis'.
    vmin : float | None, optional
        The minimum value for normalization. If None, uses the minimum of the data.
    vmax : float | None, optional
        The maximum value for normalization. If None, uses the maximum of the data.
    show_hist : bool, optional
        Whether to display a histogram of the continuous values.
    hist_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to pass to the histogram.
    hist_axes_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to define the histogram Axes. They are
        merged over the defaults ``width="25%"`` and ``height="25%"``.

    Returns
    -------
    tuple[Axes, dict[Tree, Color]] | tuple[Axes, dict[Tree, Color], Axes]
        The Axes, a dictionary mapping each node to its assigned color,
        and optionally the histogram Axes if `show_hist` is True.

    Raises
    ------
    ValueError
        If no node carries `color_by` and `vmin` or `vmax` is left to be
        inferred from the (empty) data.
    """
    if ax is None:
        ax = plt.gca()

    if isinstance(colormap, str):
        colormap = plt.get_cmap(colormap)

    features = {node: node[color_by] for node in tree if color_by in node.metadata}
    values = list(features.values())
    if not values and (vmin is None or vmax is None):
        # min()/max() would raise the same exception type on the empty list,
        # but with a message that does not mention the actual cause.
        raise ValueError(
            f"Cannot infer the color range: no node in the tree has metadata "
            f"{color_by!r}. Pass both `vmin` and `vmax` explicitly."
        )
    vmin = min(values) if vmin is None else vmin
    vmax = max(values) if vmax is None else vmax
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
    colors = {
        node: colormap(norm(float(features[node])))
        if node in features
        else default_color
        for node in tree
    }

    if show_hist:
        # Caller-supplied inset options override the default inset size.
        inset_kwargs: dict[str, Any] = {"width": "25%", "height": "25%"}
        if hist_axes_kwargs is not None:
            inset_kwargs.update(hist_axes_kwargs)
        hist_ax = inset_axes(ax, **inset_kwargs)  # pyright: ignore

        hist_kwargs = {} if hist_kwargs is None else hist_kwargs
        _, bins, patches = hist_ax.hist(values, **hist_kwargs)  # pyright: ignore

        # Color each bar with the colormap value at the middle of its bin.
        for patch, b0, b1 in zip(patches, bins[:-1], bins[1:]):  # pyright: ignore
            midpoint = (b0 + b1) / 2  # pyright: ignore
            patch.set_facecolor(colormap(norm(midpoint)))  # pyright: ignore
        return ax, colors, hist_ax  # pyright: ignore

    # Without a histogram, a colorbar conveys the value -> color mapping.
    sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)
    ax.get_figure().colorbar(sm, ax=ax)  # pyright: ignore
    return ax, colors
465
+
466
+
467
@overload
def draw_colored_tree_continuous(
    tree: Tree,
    color_by: str,
    ax: Axes | None = None,
    backward_time: bool = False,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    *,
    show_hist: Literal[False],
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> Axes: ...
# `show_hist` defaults to True at runtime, so this overload must also match
# calls that omit it.
@overload
def draw_colored_tree_continuous(
    tree: Tree,
    color_by: str,
    ax: Axes | None = None,
    backward_time: bool = False,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    *,
    show_hist: Literal[True] = True,
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> tuple[Axes, Axes]: ...
def draw_colored_tree_continuous(
    tree: Tree,
    color_by: str,
    ax: Axes | None = None,
    backward_time: bool = False,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    show_hist: bool = True,
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> Axes | tuple[Axes, Axes]:
    """
    Draw a phylogenetic tree with branches colored based on continuous metadata.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree to draw.
    color_by : str
        The metadata key to color branches by.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, uses the current Axes.
    backward_time : bool, optional
        If True, the x-axis is inverted to represent time going backward.
    default_color : Color, optional
        The color to use for nodes without the specified metadata.
    colormap : str | Colormap, optional
        The colormap to use for coloring continuous values. Defaults to 'viridis'.
    vmin : float | None, optional
        The minimum value for normalization. If None, uses the minimum of the data.
    vmax : float | None, optional
        The maximum value for normalization. If None, uses the maximum of the data.
    show_hist : bool, optional
        Whether to display a histogram of the continuous values.
    hist_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to pass to the histogram.
    hist_axes_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to define the histogram Axes.

    Returns
    -------
    Axes | tuple[Axes, Axes]
        The Axes with the drawn colored tree,
        and optionally the histogram Axes if `show_hist` is True.
    """
    # The two branches differ only in the shape of the helper's return value;
    # they are kept separate so each call resolves to a single overload.
    if show_hist:
        ax, colors, hist_ax = _init_colored_tree_continuous(
            tree=tree,
            color_by=color_by,
            ax=ax,
            default_color=default_color,
            colormap=colormap,
            vmin=vmin,
            vmax=vmax,
            show_hist=show_hist,
            hist_kwargs=hist_kwargs,
            hist_axes_kwargs=hist_axes_kwargs,
        )
        return draw_tree(
            tree=tree, ax=ax, colors=colors, backward_time=backward_time
        ), hist_ax

    ax, colors = _init_colored_tree_continuous(
        tree=tree,
        color_by=color_by,
        ax=ax,
        default_color=default_color,
        colormap=colormap,
        vmin=vmin,
        vmax=vmax,
        show_hist=show_hist,
        hist_kwargs=hist_kwargs,
        hist_axes_kwargs=hist_axes_kwargs,
    )
    return draw_tree(tree=tree, ax=ax, colors=colors, backward_time=backward_time)
574
+
575
+
576
@overload
def draw_colored_dated_tree_continuous(
    tree: Tree,
    calibration_nodes: tuple[CalibrationNode, CalibrationNode],
    color_by: str,
    ax: Axes | None = None,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    *,
    show_hist: Literal[False],
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> Axes: ...
# `show_hist` defaults to True at runtime, so this overload must also match
# calls that omit it.
@overload
def draw_colored_dated_tree_continuous(
    tree: Tree,
    calibration_nodes: tuple[CalibrationNode, CalibrationNode],
    color_by: str,
    ax: Axes | None = None,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    *,
    show_hist: Literal[True] = True,
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> tuple[Axes, Axes]: ...
def draw_colored_dated_tree_continuous(
    tree: Tree,
    calibration_nodes: tuple[CalibrationNode, CalibrationNode],
    color_by: str,
    ax: Axes | None = None,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    show_hist: bool = True,
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> Axes | tuple[Axes, Axes]:
    """
    Draw a dated phylogenetic tree with branches colored based on continuous metadata.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree to draw.
    calibration_nodes : tuple[CalibrationNode, CalibrationNode]
        Two calibration nodes defining the mapping from depth to date.
    color_by : str
        The metadata key to color branches by.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, uses the current Axes.
    default_color : Color, optional
        The color to use for nodes without the specified metadata.
    colormap : str | Colormap, optional
        The colormap to use for coloring continuous values. Defaults to 'viridis'.
    vmin : float | None, optional
        The minimum value for normalization. If None, uses the minimum of the data.
    vmax : float | None, optional
        The maximum value for normalization. If None, uses the maximum of the data.
    show_hist : bool, optional
        Whether to display a histogram of the continuous values.
    hist_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to pass to the histogram.
    hist_axes_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to define the histogram Axes.

    Returns
    -------
    Axes | tuple[Axes, Axes]
        The Axes with the drawn colored dated tree,
        and optionally the histogram Axes if `show_hist` is True.
    """
    # The two branches differ only in the shape of the helper's return value;
    # they are kept separate so each call resolves to a single overload.
    if show_hist:
        ax, colors, hist_ax = _init_colored_tree_continuous(
            tree=tree,
            color_by=color_by,
            ax=ax,
            default_color=default_color,
            colormap=colormap,
            vmin=vmin,
            vmax=vmax,
            show_hist=show_hist,
            hist_kwargs=hist_kwargs,
            hist_axes_kwargs=hist_axes_kwargs,
        )
        return draw_dated_tree(
            tree=tree,
            calibration_nodes=calibration_nodes,
            ax=ax,
            colors=colors,
        ), hist_ax

    ax, colors = _init_colored_tree_continuous(
        tree=tree,
        color_by=color_by,
        ax=ax,
        default_color=default_color,
        colormap=colormap,
        vmin=vmin,
        vmax=vmax,
        show_hist=show_hist,
        hist_kwargs=hist_kwargs,
        hist_axes_kwargs=hist_axes_kwargs,
    )
    return draw_dated_tree(
        tree=tree,
        calibration_nodes=calibration_nodes,
        ax=ax,
        colors=colors,
    )