phylogenie 3.1.4__py3-none-any.whl → 3.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
phylogenie/draw.py CHANGED
@@ -1,7 +1,9 @@
1
- from enum import Enum
2
- from typing import Any, Callable
1
+ from dataclasses import dataclass
2
+ from datetime import datetime
3
+ from typing import Any, Literal, overload
3
4
 
4
5
  import matplotlib.colors as mcolors
6
+ import matplotlib.dates as mdates
5
7
  import matplotlib.patches as mpatches
6
8
  import matplotlib.pyplot as plt
7
9
  from matplotlib.axes import Axes
@@ -11,15 +13,43 @@ from mpl_toolkits.axes_grid1.inset_locator import inset_axes # pyright: ignore
11
13
  from phylogenie.treesimulator import Tree, get_node_depth_levels, get_node_depths
12
14
 
13
15
 
14
- class Coloring(str, Enum):
15
- DISCRETE = "discrete"
16
- CONTINUOUS = "continuous"
16
@dataclass
class CalibrationNode:
    """
    A tree node pinned to a known calendar date.

    Two of these define the linear depth-to-date mapping used by the dated
    tree drawing functions.

    Attributes
    ----------
    node : Tree
        The calibrated node of the tree.
    date : datetime
        The calendar date assigned to `node`.
    """

    node: Tree  # the calibrated node
    date: datetime  # its known date


# A color specification: a string (e.g. a color name or hex code) or an
# RGB / RGBA tuple of floats.
Color = str | tuple[float, float, float] | tuple[float, float, float, float]
20
23
 
21
24
 
22
- def _draw_colored_tree(tree: Tree, ax: Axes, colors: Color | dict[Tree, Color]) -> Axes:
25
+ def draw_tree(
26
+ tree: Tree,
27
+ ax: Axes | None = None,
28
+ colors: Color | dict[Tree, Color] = "black",
29
+ backward_time: bool = False,
30
+ ) -> Axes:
31
+ """
32
+ Draw a phylogenetic tree with colored branches.
33
+
34
+ Parameters
35
+ ----------
36
+ tree : Tree
37
+ The phylogenetic tree to draw.
38
+ ax : Axes | None, optional
39
+ The matplotlib Axes to draw on. If None, uses the current Axes.
40
+ colors : Color | dict[Tree, Color], optional
41
+ A single color for all branches or a dictionary mapping each node to a color.
42
+ backward_time : bool, optional
43
+ If True, the x-axis is inverted to represent time going backward.
44
+
45
+ Returns
46
+ -------
47
+ Axes
48
+ The Axes with the drawn tree.
49
+ """
50
+ if ax is None:
51
+ ax = plt.gca()
52
+
23
53
  if not isinstance(colors, dict):
24
54
  colors = {node: colors for node in tree}
25
55
 
@@ -28,122 +58,633 @@ def _draw_colored_tree(tree: Tree, ax: Axes, colors: Color | dict[Tree, Color])
28
58
  if any(node.branch_length is None for node in tree.iter_descendants())
29
59
  else get_node_depths(tree)
30
60
  )
61
+ if backward_time:
62
+ max_x = max(xs.values())
63
+ xs = {node: max_x - x for node, x in xs.items()}
64
+
31
65
  ys: dict[Tree, float] = {node: i for i, node in enumerate(tree.get_leaves())}
32
66
  for node in tree.postorder_traversal():
33
67
  if node.is_internal():
34
68
  ys[node] = sum(ys[child] for child in node.children) / len(node.children)
35
69
 
70
+ if tree.branch_length is not None:
71
+ xmin = xs[tree] + tree.branch_length if backward_time else 0
72
+ ax.hlines(y=ys[tree], xmin=xmin, xmax=xs[tree], color=colors[tree]) # pyright: ignore
36
73
  for node in tree:
37
74
  x1, y1 = xs[node], ys[node]
38
- if node.parent is None:
39
- ax.hlines(y=y1, xmin=0, xmax=x1, color=colors[node]) # pyright: ignore
40
- continue
41
- x0, y0 = xs[node.parent], ys[node.parent]
42
- ax.vlines(x=x0, ymin=y0, ymax=y1, color=colors[node]) # pyright: ignore
43
- ax.hlines(y=y1, xmin=x0, xmax=x1, color=colors[node]) # pyright: ignore
75
+ for child in node.children:
76
+ x2, y2 = xs[child], ys[child]
77
+ ax.hlines(y=y2, xmin=x1, xmax=x2, color=colors[child]) # pyright: ignore
78
+ ax.vlines(x=x1, ymin=y1, ymax=y2, color=colors[child]) # pyright: ignore
79
+
80
+ if backward_time:
81
+ ax.invert_xaxis()
44
82
 
45
83
  ax.set_yticks([]) # pyright: ignore
46
84
  return ax
47
85
 
48
86
 
49
- def draw_tree(
87
def _depth_to_date(
    depth: float, calibration_nodes: tuple[CalibrationNode, CalibrationNode]
) -> datetime:
    """
    Convert a depth value to a date using linear interpolation between two calibration nodes.

    The mapping is the straight line through ``(depth1, date1)`` and
    ``(depth2, date2)``. Depths outside the calibrated range are therefore
    extrapolated rather than clipped.

    Parameters
    ----------
    depth : float
        The depth value to convert.
    calibration_nodes : tuple[CalibrationNode, CalibrationNode]
        Two calibration nodes defining the mapping from depth to date. Their
        nodes must lie at different depths.

    Returns
    -------
    datetime
        The interpolated date corresponding to the given depth.

    Raises
    ------
    ValueError
        If both calibration nodes are at the same depth, in which case the
        depth-to-date mapping is undefined.
    """
    node1, node2 = calibration_nodes
    # NOTE(review): calibration depths are read from `Tree.depth`, while the
    # drawing code positions nodes with `get_node_depths`. Both are assumed to
    # share the same origin — confirm.
    depth1, depth2 = node1.node.depth, node2.node.depth
    if depth1 == depth2:
        # Otherwise the timedelta division below fails with an opaque
        # ZeroDivisionError.
        raise ValueError(
            "Calibration nodes must be at different depths to define a "
            f"depth-to-date mapping (both are at depth {depth1})."
        )
    date1, date2 = node1.date, node2.date
    # timedelta supports multiplication and division by numbers, so this is a
    # plain linear interpolation carried out in timedelta space.
    return date1 + (depth - depth1) * (date2 - date1) / (depth2 - depth1)
109
+
110
+
111
def draw_dated_tree(
    tree: Tree,
    calibration_nodes: tuple[CalibrationNode, CalibrationNode],
    ax: Axes | None = None,
    colors: Color | dict[Tree, Color] = "black",
) -> Axes:
    """
    Draw a phylogenetic tree against a calendar-date x-axis.

    Every node's depth (as returned by `get_node_depths`) is converted to a
    date by linear interpolation between the two calibration nodes. Each
    branch is drawn as a horizontal segment ending at the child's date, plus
    a vertical connector at the parent's date. If the root has a branch
    length, a stem is drawn from the date of depth 0 to the root.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree to draw.
    calibration_nodes : tuple[CalibrationNode, CalibrationNode]
        Two calibration nodes defining the mapping from depth to date.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, the current Axes is used.
    colors : Color | dict[Tree, Color], optional
        One color for every branch, or a per-node mapping. A node's entry
        colors the branch leading to it.

    Returns
    -------
    Axes
        The Axes holding the dated tree, with date-formatted x ticks
        (``YYYY-MM-DD``, rotated 45 degrees) and no y ticks.
    """
    target = plt.gca() if ax is None else ax
    palette = colors if isinstance(colors, dict) else {node: colors for node in tree}

    # Matplotlib plots dates as floats; convert each node's date only once.
    date_nums = {
        node: mdates.date2num(  # pyright: ignore
            _depth_to_date(depth=depth, calibration_nodes=calibration_nodes)
        )
        for node, depth in get_node_depths(tree).items()
    }

    # Leaves take consecutive rows; an internal node sits at its children's
    # mean row (postorder guarantees children are placed first).
    rows: dict[Tree, float] = {}
    for index, leaf in enumerate(tree.get_leaves()):
        rows[leaf] = index
    for node in tree.postorder_traversal():
        if node.is_internal():
            kids = node.children
            rows[node] = sum(rows[kid] for kid in kids) / len(kids)

    if tree.branch_length is not None:
        stem_start = mdates.date2num(  # pyright: ignore
            _depth_to_date(depth=0, calibration_nodes=calibration_nodes)
        )
        target.hlines(  # pyright: ignore
            y=rows[tree], xmin=stem_start, xmax=date_nums[tree], color=palette[tree]
        )

    for parent in tree:
        parent_x, parent_y = date_nums[parent], rows[parent]
        for child in parent.children:
            child_x, child_y = date_nums[child], rows[child]
            target.hlines(  # pyright: ignore
                y=child_y, xmin=parent_x, xmax=child_x, color=palette[child]
            )
            target.vlines(  # pyright: ignore
                x=parent_x, ymin=parent_y, ymax=child_y, color=palette[child]
            )

    target.xaxis.set_major_locator(mdates.AutoDateLocator())
    target.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d"))
    target.tick_params(axis="x", labelrotation=45)  # pyright: ignore

    target.set_yticks([])  # pyright: ignore
    return target
178
+
179
+
180
def _init_colored_tree_categorical(
    tree: Tree,
    color_by: str,
    ax: Axes | None = None,
    default_color: Color = "black",
    colormap: str | Colormap = "tab20",
    show_legend: bool = True,
    labels: dict[Any, str] | None = None,
    legend_kwargs: dict[str, Any] | None = None,
) -> tuple[Axes, dict[Tree, Color]]:
    """
    Initialize colors for drawing a tree based on categorical metadata.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree.
    color_by : str
        The metadata key to color branches by.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, uses the current Axes.
    default_color : Color, optional
        The color to use for nodes without the specified metadata.
    colormap : str | Colormap, optional
        The colormap to use for coloring categories. Defaults to 'tab20'.
    show_legend : bool, optional
        Whether to display a legend for the categories.
    labels : dict[Any, str] | None, optional
        A mapping from category values to labels for the legend. Categories
        missing from the mapping fall back to ``str(value)``.
    legend_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to pass to the legend.

    Returns
    -------
    tuple[Axes, dict[Tree, Color]]
        The Axes and a dictionary mapping each node to its assigned color.

    Notes
    -----
    Categories receive colormap entries in sorted order, or in first-seen
    order when the values are not mutually comparable. Either way the
    assignment is reproducible. Iterating a ``set`` would depend on hash
    randomization for string categories and could change between runs.
    """
    if ax is None:
        ax = plt.gca()

    if isinstance(colormap, str):
        colormap = plt.get_cmap(colormap)

    features = {node: node[color_by] for node in tree if color_by in node.metadata}

    # De-duplicate while keeping first-seen order, then sort when possible.
    # sorted() (not list.sort) leaves the fallback list intact if a
    # comparison raises.
    unique_values = list(dict.fromkeys(features.values()))
    try:
        categories = sorted(unique_values)
    except TypeError:
        categories = unique_values  # mixed, non-comparable category types

    feature_colors = {f: mcolors.to_hex(colormap(i)) for i, f in enumerate(categories)}
    colors = {
        node: feature_colors[features[node]] if node in features else default_color
        for node in tree
    }

    if show_legend:
        legend_handles = [
            mpatches.Patch(
                color=color,
                label=str(f) if labels is None else labels.get(f, str(f)),
            )
            for f, color in feature_colors.items()
        ]
        # Nodes lacking the key are drawn in default_color; label that "NA".
        if any(color_by not in node.metadata for node in tree):
            legend_handles.append(mpatches.Patch(color=default_color, label="NA"))
        ax.legend(handles=legend_handles, **(legend_kwargs or {}))  # pyright: ignore

    return ax, colors
247
+
248
+
249
def draw_colored_tree_categorical(
    tree: Tree,
    color_by: str,
    ax: Axes | None = None,
    backward_time: bool = False,
    default_color: Color = "black",
    colormap: str | Colormap = "tab20",
    show_legend: bool = True,
    labels: dict[Any, str] | None = None,
    legend_kwargs: dict[str, Any] | None = None,
) -> Axes:
    """
    Draw a phylogenetic tree with branches colored based on categorical metadata.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree to draw.
    color_by : str
        The metadata key to color branches by.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, uses the current Axes.
    backward_time : bool, optional
        If True, the x-axis is inverted to represent time going backward.
    default_color : Color, optional
        The color to use for nodes without the specified metadata.
    colormap : str | Colormap, optional
        The colormap to use for coloring categories. Defaults to 'tab20'.
    show_legend : bool, optional
        Whether to display a legend for the categories.
    labels : dict[Any, str] | None, optional
        A mapping from category values to labels for the legend.
    legend_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to pass to the legend.

    Returns
    -------
    Axes
        The Axes with the drawn colored tree.
    """
    # Per-node colors (and the optional legend) come from the shared helper;
    # the drawing itself is delegated to draw_tree.
    ax, colors = _init_colored_tree_categorical(
        tree=tree,
        color_by=color_by,
        ax=ax,
        default_color=default_color,
        colormap=colormap,
        show_legend=show_legend,
        labels=labels,
        legend_kwargs=legend_kwargs,
    )
    return draw_tree(tree=tree, ax=ax, colors=colors, backward_time=backward_time)
300
+
301
+
302
def draw_colored_dated_tree_categorical(
    tree: Tree,
    calibration_nodes: tuple[CalibrationNode, CalibrationNode],
    color_by: str,
    ax: Axes | None = None,
    default_color: Color = "black",
    colormap: str | Colormap = "tab20",
    show_legend: bool = True,
    labels: dict[Any, str] | None = None,
    legend_kwargs: dict[str, Any] | None = None,
) -> Axes:
    """
    Draw a calendar-dated tree whose branches are colored by a categorical key.

    Colors (and, optionally, the legend) are prepared by
    `_init_colored_tree_categorical`. The tree is then drawn with
    `draw_dated_tree`.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree to draw.
    calibration_nodes : tuple[CalibrationNode, CalibrationNode]
        Two calibration nodes defining the mapping from depth to date.
    color_by : str
        The metadata key whose value selects each branch's color.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, the current Axes is used.
    default_color : Color, optional
        The color for nodes that lack the `color_by` metadata.
    colormap : str | Colormap, optional
        The colormap categories are drawn from. Defaults to 'tab20'.
    show_legend : bool, optional
        Whether to add a legend listing the categories.
    labels : dict[Any, str] | None, optional
        Legend labels keyed by category value.
    legend_kwargs : dict[str, Any] | None, optional
        Extra keyword arguments forwarded to the legend.

    Returns
    -------
    Axes
        The Axes holding the colored dated tree.
    """
    target_ax, node_colors = _init_colored_tree_categorical(
        tree,
        color_by,
        ax=ax,
        default_color=default_color,
        colormap=colormap,
        show_legend=show_legend,
        labels=labels,
        legend_kwargs=legend_kwargs,
    )
    dated_ax = draw_dated_tree(
        tree, calibration_nodes, ax=target_ax, colors=node_colors
    )
    return dated_ax
355
+
356
+
357
@overload
def _init_colored_tree_continuous(
    tree: Tree,
    color_by: str,
    ax: Axes | None = ...,
    default_color: Color = ...,
    colormap: str | Colormap = ...,
    vmin: float | None = ...,
    vmax: float | None = ...,
    *,
    show_hist: Literal[False],
    hist_kwargs: dict[str, Any] | None = ...,
    hist_axes_kwargs: dict[str, Any] | None = ...,
) -> tuple[Axes, dict[Tree, Color]]: ...
@overload
def _init_colored_tree_continuous(
    tree: Tree,
    color_by: str,
    ax: Axes | None = ...,
    default_color: Color = ...,
    colormap: str | Colormap = ...,
    vmin: float | None = ...,
    vmax: float | None = ...,
    *,
    # Defaulted so that calls omitting `show_hist` (runtime default: True)
    # still match an overload.
    show_hist: Literal[True] = ...,
    hist_kwargs: dict[str, Any] | None = ...,
    hist_axes_kwargs: dict[str, Any] | None = ...,
) -> tuple[Axes, dict[Tree, Color], Axes]: ...
def _init_colored_tree_continuous(
    tree: Tree,
    color_by: str,
    ax: Axes | None = None,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    *,
    show_hist: bool = True,
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> tuple[Axes, dict[Tree, Color]] | tuple[Axes, dict[Tree, Color], Axes]:
    """
    Initialize colors for drawing a tree based on continuous metadata.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree.
    color_by : str
        The metadata key to color branches by.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, uses the current Axes.
    default_color : Color, optional
        The color to use for nodes without the specified metadata.
    colormap : str | Colormap, optional
        The colormap to use for coloring continuous values. Defaults to 'viridis'.
    vmin : float | None, optional
        The minimum value for normalization. If None, uses the minimum of the data.
    vmax : float | None, optional
        The maximum value for normalization. If None, uses the maximum of the data.
    show_hist : bool, optional
        Whether to display a histogram of the continuous values (shown in an
        inset Axes instead of a colorbar).
    hist_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to pass to the histogram.
    hist_axes_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to define the histogram Axes.

    Returns
    -------
    tuple[Axes, dict[Tree, Color]] | tuple[Axes, dict[Tree, Color], Axes]
        The Axes, a dictionary mapping each node to its assigned color,
        and optionally the histogram Axes if `show_hist` is True.

    Raises
    ------
    ValueError
        If no node carries `color_by` and `vmin`/`vmax` are not both given,
        because the color range cannot then be inferred from the data.
    """
    if ax is None:
        ax = plt.gca()

    if isinstance(colormap, str):
        colormap = plt.get_cmap(colormap)

    features = {node: node[color_by] for node in tree if color_by in node.metadata}
    values = list(features.values())
    if not values and (vmin is None or vmax is None):
        # Without this check, min()/max() fail with an unhelpful
        # "arg is an empty sequence" message.
        raise ValueError(
            f"No node has metadata '{color_by}', so the color range cannot be "
            "inferred; pass both vmin and vmax explicitly."
        )
    vmin = min(values) if vmin is None else vmin
    vmax = max(values) if vmax is None else vmax
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
    colors = {
        node: colormap(norm(float(features[node])))
        if node in features
        else default_color
        for node in tree
    }

    if show_hist:
        default_hist_axes_kwargs = {"width": "25%", "height": "25%"}
        if hist_axes_kwargs is not None:
            default_hist_axes_kwargs.update(hist_axes_kwargs)
        hist_ax = inset_axes(ax, **default_hist_axes_kwargs)  # pyright: ignore

        hist_kwargs = {} if hist_kwargs is None else hist_kwargs
        _, bins, patches = hist_ax.hist(values, **hist_kwargs)  # pyright: ignore

        # Color each bar by its bin midpoint, using the same colormap/norm as
        # the branches, so bar colors match branch colors for those values.
        for patch, b0, b1 in zip(patches, bins[:-1], bins[1:]):  # pyright: ignore
            midpoint = (b0 + b1) / 2  # pyright: ignore
            patch.set_facecolor(colormap(norm(midpoint)))  # pyright: ignore
        return ax, colors, hist_ax  # pyright: ignore

    sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)
    ax.get_figure().colorbar(sm, ax=ax)  # pyright: ignore
    return ax, colors
465
+
466
+
467
@overload
def draw_colored_tree_continuous(
    tree: Tree,
    color_by: str,
    ax: Axes | None = None,
    backward_time: bool = False,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    *,
    show_hist: Literal[False],
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> Axes: ...
@overload
def draw_colored_tree_continuous(
    tree: Tree,
    color_by: str,
    ax: Axes | None = None,
    backward_time: bool = False,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    *,
    # Defaulted to mirror the implementation, so a call that omits
    # `show_hist` type-checks and is known to return (tree Axes, hist Axes).
    show_hist: Literal[True] = True,
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> tuple[Axes, Axes]: ...
def draw_colored_tree_continuous(
    tree: Tree,
    color_by: str,
    ax: Axes | None = None,
    backward_time: bool = False,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    show_hist: bool = True,
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> Axes | tuple[Axes, Axes]:
    """
    Draw a phylogenetic tree with branches colored based on continuous metadata.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree to draw.
    color_by : str
        The metadata key to color branches by.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, uses the current Axes.
    backward_time : bool, optional
        If True, the x-axis is inverted to represent time going backward.
    default_color : Color, optional
        The color to use for nodes without the specified metadata.
    colormap : str | Colormap, optional
        The colormap to use for coloring continuous values. Defaults to 'viridis'.
    vmin : float | None, optional
        The minimum value for normalization. If None, uses the minimum of the data.
    vmax : float | None, optional
        The maximum value for normalization. If None, uses the maximum of the data.
    show_hist : bool, optional
        Whether to display a histogram of the continuous values.
    hist_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to pass to the histogram.
    hist_axes_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to define the histogram Axes.

    Returns
    -------
    Axes | tuple[Axes, Axes]
        The Axes with the drawn colored tree,
        and optionally the histogram Axes if `show_hist` is True.

    Raises
    ------
    ValueError
        If no node carries `color_by` and `vmin`/`vmax` are not both given.
    """
    # Two branches (rather than one call with show_hist=show_hist) let type
    # checkers pick the matching helper overload and its tuple arity.
    if show_hist:
        ax, colors, hist_ax = _init_colored_tree_continuous(
            tree=tree,
            color_by=color_by,
            ax=ax,
            default_color=default_color,
            colormap=colormap,
            vmin=vmin,
            vmax=vmax,
            show_hist=show_hist,
            hist_kwargs=hist_kwargs,
            hist_axes_kwargs=hist_axes_kwargs,
        )
        tree_ax = draw_tree(
            tree=tree, ax=ax, colors=colors, backward_time=backward_time
        )
        return tree_ax, hist_ax

    ax, colors = _init_colored_tree_continuous(
        tree=tree,
        color_by=color_by,
        ax=ax,
        default_color=default_color,
        colormap=colormap,
        vmin=vmin,
        vmax=vmax,
        show_hist=show_hist,
        hist_kwargs=hist_kwargs,
        hist_axes_kwargs=hist_axes_kwargs,
    )
    return draw_tree(tree=tree, ax=ax, colors=colors, backward_time=backward_time)
93
574
 
94
- if coloring == Coloring.DISCRETE:
95
- if any(isinstance(f, float) for f in values):
96
- raise ValueError(
97
- "Discrete coloring selected but feature values are not all categorical."
98
- )
99
- feature_colors = {
100
- f: mcolors.to_hex(colormap(i)) for i, f in enumerate(set(values))
101
- }
102
- colors = _get_colors(lambda f: feature_colors[f])
103
-
104
- if show_legend:
105
- legend_handles = [
106
- mpatches.Patch(
107
- color=feature_colors[f],
108
- label=str(f) if labels is None else labels[f],
109
- )
110
- for f in feature_colors
111
- ]
112
- if any(color_by not in node.metadata for node in tree):
113
- legend_handles.append(mpatches.Patch(color=default_color, label="NA"))
114
- if legend_kwargs is None:
115
- legend_kwargs = {}
116
- ax.legend(handles=legend_handles, **legend_kwargs) # pyright: ignore
117
-
118
- return _draw_colored_tree(tree, ax, colors)
119
-
120
- if coloring == Coloring.CONTINUOUS:
121
- vmin = min(values) if vmin is None else vmin
122
- vmax = max(values) if vmax is None else vmax
123
- norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
124
- colors = _get_colors(lambda f: colormap(norm(float(f))))
125
-
126
- if show_hist:
127
- default_hist_axes_kwargs = {"width": "25%", "height": "25%"}
128
- if hist_axes_kwargs is not None:
129
- default_hist_axes_kwargs.update(hist_axes_kwargs)
130
- hist_ax = inset_axes(ax, **default_hist_axes_kwargs) # pyright: ignore
131
-
132
- hist_kwargs = {} if hist_kwargs is None else hist_kwargs
133
- _, bins, patches = hist_ax.hist(values, **hist_kwargs) # pyright: ignore
134
-
135
- for patch, b0, b1 in zip( # pyright: ignore
136
- patches, bins[:-1], bins[1:] # pyright: ignore
137
- ):
138
- midpoint = (b0 + b1) / 2 # pyright: ignore
139
- patch.set_facecolor(colormap(norm(midpoint))) # pyright: ignore
140
- return _draw_colored_tree(tree, ax, colors), hist_ax # pyright: ignore
141
-
142
- else:
143
- sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)
144
- ax.get_figure().colorbar(sm, ax=ax) # pyright: ignore
145
- return _draw_colored_tree(tree, ax, colors)
146
-
147
- raise ValueError(
148
- f"Unknown coloring method: {coloring}. Choices are {list(Coloring)}."
575
+
576
@overload
def draw_colored_dated_tree_continuous(
    tree: Tree,
    calibration_nodes: tuple[CalibrationNode, CalibrationNode],
    color_by: str,
    ax: Axes | None = None,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    *,
    show_hist: Literal[False],
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> Axes: ...
@overload
def draw_colored_dated_tree_continuous(
    tree: Tree,
    calibration_nodes: tuple[CalibrationNode, CalibrationNode],
    color_by: str,
    ax: Axes | None = None,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    *,
    # Defaulted to mirror the implementation, so a call that omits
    # `show_hist` type-checks and is known to return (tree Axes, hist Axes).
    show_hist: Literal[True] = True,
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> tuple[Axes, Axes]: ...
def draw_colored_dated_tree_continuous(
    tree: Tree,
    calibration_nodes: tuple[CalibrationNode, CalibrationNode],
    color_by: str,
    ax: Axes | None = None,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    show_hist: bool = True,
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> Axes | tuple[Axes, Axes]:
    """
    Draw a dated phylogenetic tree with branches colored based on continuous metadata.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree to draw.
    calibration_nodes : tuple[CalibrationNode, CalibrationNode]
        Two calibration nodes defining the mapping from depth to date.
    color_by : str
        The metadata key to color branches by.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, uses the current Axes.
    default_color : Color, optional
        The color to use for nodes without the specified metadata.
    colormap : str | Colormap, optional
        The colormap to use for coloring continuous values. Defaults to 'viridis'.
    vmin : float | None, optional
        The minimum value for normalization. If None, uses the minimum of the data.
    vmax : float | None, optional
        The maximum value for normalization. If None, uses the maximum of the data.
    show_hist : bool, optional
        Whether to display a histogram of the continuous values.
    hist_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to pass to the histogram.
    hist_axes_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to define the histogram Axes.

    Returns
    -------
    Axes | tuple[Axes, Axes]
        The Axes with the drawn colored dated tree,
        and optionally the histogram Axes if `show_hist` is True.

    Raises
    ------
    ValueError
        If no node carries `color_by` and `vmin`/`vmax` are not both given.
    """
    # Two branches (rather than one call with show_hist=show_hist) let type
    # checkers pick the matching helper overload and its tuple arity.
    if show_hist:
        ax, colors, hist_ax = _init_colored_tree_continuous(
            tree=tree,
            color_by=color_by,
            ax=ax,
            default_color=default_color,
            colormap=colormap,
            vmin=vmin,
            vmax=vmax,
            show_hist=show_hist,
            hist_kwargs=hist_kwargs,
            hist_axes_kwargs=hist_axes_kwargs,
        )
        tree_ax = draw_dated_tree(
            tree=tree,
            calibration_nodes=calibration_nodes,
            ax=ax,
            colors=colors,
        )
        return tree_ax, hist_ax

    ax, colors = _init_colored_tree_continuous(
        tree=tree,
        color_by=color_by,
        ax=ax,
        default_color=default_color,
        colormap=colormap,
        vmin=vmin,
        vmax=vmax,
        show_hist=show_hist,
        hist_kwargs=hist_kwargs,
        hist_axes_kwargs=hist_axes_kwargs,
    )
    return draw_dated_tree(
        tree=tree,
        calibration_nodes=calibration_nodes,
        ax=ax,
        colors=colors,
    )
phylogenie/io/fasta.py CHANGED
@@ -24,3 +24,11 @@ def load_fasta(
24
24
  chars = next(f).strip()
25
25
  sequences.append(Sequence(id, chars, time))
26
26
  return MSA(sequences)
27
+
28
+
29
def dump_fasta(msa: MSA | list[Sequence], fasta_file: str | Path) -> None:
    """
    Write sequences to `fasta_file` in FASTA format, overwriting it.

    Each sequence becomes a ``>id`` header line followed by one line holding
    all of its characters (no line wrapping).

    Parameters
    ----------
    msa : MSA | list[Sequence]
        An alignment (its ``sequences`` are written) or a plain list of
        sequences.
    fasta_file : str | Path
        Destination path.
    """
    with open(fasta_file, "w") as handle:
        records = msa.sequences if isinstance(msa, MSA) else msa
        handle.writelines(f">{record.id}\n{record.chars}\n" for record in records)
phylogenie/main.py CHANGED
@@ -2,23 +2,44 @@ import os
2
2
  from argparse import ArgumentParser
3
3
  from glob import glob
4
4
 
5
- from pydantic import TypeAdapter
5
+ from pydantic import TypeAdapter, ValidationError
6
6
  from yaml import safe_load
7
7
 
8
8
  from phylogenie.generators import DatasetGeneratorConfig
9
9
  from phylogenie.generators.dataset import DatasetGenerator
10
10
 
11
11
 
12
- def run(config_path: str) -> None:
12
def _format_validation_error(e: ValidationError) -> str:
    """
    Render a pydantic ValidationError as one bullet line per error.

    Each line has the form ``- <dotted.location>: <message> (<error type>)``.
    """
    lines: list[str] = []
    for err in e.errors():
        location = ".".join(map(str, err["loc"]))
        lines.append(f"- {location}: {err['msg']} ({err['type']})")
    return "\n".join(lines)
18
+
19
+
20
def _generate_from_config_file(config_file: str) -> None:
    """
    Load, validate, and run a single YAML dataset-generator config.

    If the YAML cannot be parsed, or the parsed content fails validation, an
    error message naming the file is printed to stderr and the process exits
    with status 1.

    Parameters
    ----------
    config_file : str
        Path to the YAML configuration file.
    """
    # `exit()` is injected by the `site` module for interactive use and is not
    # guaranteed to exist in programs; sys.exit is the supported API.
    import sys

    adapter: TypeAdapter[DatasetGenerator] = TypeAdapter(DatasetGeneratorConfig)
    with open(config_file, "r") as f:
        try:
            config = safe_load(f)
        except Exception as e:  # CLI boundary: report any parse failure and stop.
            print(f"❌ Failed to parse {config_file}: {e}", file=sys.stderr)
            sys.exit(1)
    try:
        generator = adapter.validate_python(config)
    except ValidationError as e:
        # Name the file: when a whole directory is processed, the user needs
        # to know which config is invalid.
        print(f"❌ Invalid configuration in {config_file}:", file=sys.stderr)
        print(_format_validation_error(e), file=sys.stderr)
        sys.exit(1)
    generator.generate()
14
35
 
36
+
37
def run(config_path: str) -> None:
    """
    Generate datasets from a YAML config file or a directory of them.

    If `config_path` is a directory, every ``*.yaml`` file in it or any of its
    subdirectories is processed. Otherwise `config_path` itself is treated as
    a single config file.
    """
    if not os.path.isdir(config_path):
        _generate_from_config_file(config_path)
        return
    pattern = os.path.join(config_path, "**/*.yaml")
    for config_file in glob(pattern, recursive=True):
        _generate_from_config_file(config_file)
22
43
 
23
44
 
24
45
  def main() -> None:
phylogenie/mixins.py CHANGED
@@ -1,39 +1,31 @@
1
+ from collections.abc import Mapping
1
2
  from types import MappingProxyType
2
- from typing import Any, Mapping, Optional
3
+ from typing import Any
3
4
 
4
5
 
5
6
  class MetadataMixin:
6
- """A mixin that provides metadata management with dictionary-like access."""
7
-
8
7
  def __init__(self) -> None:
9
8
  self._metadata: dict[str, Any] = {}
10
9
 
11
10
  @property
12
11
  def metadata(self) -> Mapping[str, Any]:
13
- """Return a read-only view of all metadata."""
14
12
  return MappingProxyType(self._metadata)
15
13
 
16
14
  def set(self, key: str, value: Any) -> None:
17
- """Set or update a metadata value."""
18
15
  self._metadata[key] = value
19
16
 
20
17
  def update(self, metadata: Mapping[str, Any]) -> None:
21
- """Bulk update metadata values."""
22
18
  self._metadata.update(metadata)
23
19
 
24
- def get(self, key: str, default: Optional[Any] = None) -> Any:
25
- """Get a metadata value, returning `default` if not found."""
20
+ def get(self, key: str, default: Any = None) -> Any:
26
21
  return self._metadata.get(key, default)
27
22
 
28
23
  def delete(self, key: str) -> None:
29
- """Delete a metadata if it exists, else do nothing."""
30
24
  self._metadata.pop(key, None)
31
25
 
32
26
  def clear(self) -> None:
33
- """Remove all metadata."""
34
27
  self._metadata.clear()
35
28
 
36
- # Dict-like behavior
37
29
  def __getitem__(self, key: str) -> Any:
38
30
  return self._metadata[key]
39
31
 
@@ -15,7 +15,7 @@ def _parse_translate_block(lines: Iterator[str]) -> dict[str, str]:
15
15
  if ";" in line:
16
16
  return translations
17
17
  else:
18
- raise ValueError(f"Invalid translate line. Expected '<num> <name>'.")
18
+ raise ValueError("Invalid translate line. Expected '<num> <name>'.")
19
19
  translations[match.group(1)] = match.group(2)
20
20
  raise ValueError("Translate block not terminated with ';'.")
21
21
 
@@ -33,7 +33,7 @@ def _parse_trees_block(lines: Iterator[str]) -> dict[str, Tree]:
33
33
  match = re.match(r"^TREE\s*\*?\s+(\S+)\s*=\s*(.+)$", line, re.IGNORECASE)
34
34
  if match is None:
35
35
  raise ValueError(
36
- f"Invalid tree line. Expected 'TREE <name> = <newick>'."
36
+ "Invalid tree line. Expected 'TREE <name> = <newick>'."
37
37
  )
38
38
  name = match.group(1)
39
39
  if name in trees:
@@ -159,8 +159,11 @@ class Tree(MetadataMixin):
159
159
  # Other useful miscellaneous methods.
160
160
 
161
161
  def ladderize(self, key: Callable[["Tree"], Any] | None = None) -> None:
162
+ def _default_key(node: Tree) -> int:
163
+ return node.n_leaves
164
+
162
165
  if key is None:
163
- key = lambda node: node.n_leaves
166
+ key = _default_key
164
167
  self._children.sort(key=key)
165
168
  for child in self.children:
166
169
  child.ladderize(key)
@@ -169,7 +172,7 @@ class Tree(MetadataMixin):
169
172
  for node in self:
170
173
  if node.name == name:
171
174
  return node
172
- raise ValueError(f"Node with name {name} not found.")
175
+ raise ValueError(f"Node {name} not found.")
173
176
 
174
177
  def copy(self):
175
178
  new_tree = Tree(self.name, self.branch_length)
@@ -1,7 +1,7 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: phylogenie
3
- Version: 3.1.4
4
- Summary: Add your description here
3
+ Version: 3.1.6
4
+ Summary: Generate phylogenetic datasets with minimal setup effort
5
5
  Requires-Python: >=3.10
6
6
  Description-Content-Type: text/markdown
7
7
  License-File: LICENSE.txt
@@ -1,7 +1,7 @@
1
1
  phylogenie/__init__.py,sha256=-4GKO2EKgc2GiOhzVRI0SQk6wQZ4AlP7Ron_Lp-s_r0,2913
2
- phylogenie/draw.py,sha256=nT3FOHSm85G23qz2xxYf2MesEDonWRSxfKGJkVboTKY,5580
3
- phylogenie/main.py,sha256=DstKWUySyDXdcPN_P8O3Mmoau4IjUxCQ4iLWDVuAcws,1054
4
- phylogenie/mixins.py,sha256=kwxcHGYSS-Frk-r0nwuyILYS1g6NjSOhbUKEYt2g1B0,1319
2
+ phylogenie/draw.py,sha256=djG0cG9mmNFZ7vf3jo8Ei3bNBJVqujp6gtZtwPxyfSY,23341
3
+ phylogenie/main.py,sha256=ry3B3HiwibZG3_qB58T5UhWy5dp6neYUtSqzL9LrSkA,1698
4
+ phylogenie/mixins.py,sha256=wMwqP6zkqME9eMyzx5FS6-p9X8yW09jIC8jge8pHlkk,907
5
5
  phylogenie/msa.py,sha256=JDGyZUsAq6-m-SQjoCDjAkAZIxfgyl_PDIhdYn5HOow,2064
6
6
  phylogenie/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
7
  phylogenie/typeguards.py,sha256=JtqmbEWJZBRHbWgCvcl6nrWm3VcBfzRbklbTBYHItn0,1325
@@ -14,7 +14,7 @@ phylogenie/generators/factories.py,sha256=2mTFdFbbLyV3v79JaOEVtqLOmxQHaOUv1S-Y3v
14
14
  phylogenie/generators/trees.py,sha256=8dO1CkU34E6mmMAHrYqiLV_VA8r54cSEOo-UzoHiN20,10467
15
15
  phylogenie/generators/typeguards.py,sha256=yj4VkhOaUXJ2OrY-6zhOeY9C4yKIQxjZtk2d-vIxttQ,828
16
16
  phylogenie/io/__init__.py,sha256=eWDU6YDqAdx6TZlUMmMbsfO8gP3i5HP_cfUe_-0x2FA,69
17
- phylogenie/io/fasta.py,sha256=CUFO06m7wClprarsMheZojM4Os2NQf3ALYUXSWzfNL0,869
17
+ phylogenie/io/fasta.py,sha256=kx9uVATLpzpXdhhNNMvMpB5Vdwh9CTWepTGachvEwV4,1154
18
18
  phylogenie/skyline/__init__.py,sha256=7pF4CUb4ZCLzNYJNhOjpuTOLTRhlK7L6ugfccNqjIGo,620
19
19
  phylogenie/skyline/matrix.py,sha256=v4SitY7VbXprqlqQckjWTzW5hwRmCyIF595R6IJMxWw,9268
20
20
  phylogenie/skyline/parameter.py,sha256=TVqkqirGXNN-VP8hnIJACPkOxUan6LkGa5o_JcPfwbY,4834
@@ -23,7 +23,7 @@ phylogenie/treesimulator/__init__.py,sha256=Qz__cFookXuEjq8AUp1A6XRON0qQ_xX-q2Q5
23
23
  phylogenie/treesimulator/features.py,sha256=XbuwGw8xjGs2lNhJvvUUvXVtheSTBaSN6qj39tWYEro,1391
24
24
  phylogenie/treesimulator/gillespie.py,sha256=ey2hdpJOSpNW88duwK7wTAdYSTnSuTSZ_yhZv9MlNHo,5323
25
25
  phylogenie/treesimulator/model.py,sha256=L0RsL6H1ynFDPecULniSs4Cs8dvz87ovviQOXFy5Qt0,4580
26
- phylogenie/treesimulator/tree.py,sha256=-yMW14018x9dw45TonS6nRlzWcwXcaHv3Jn5HYriLQQ,6009
26
+ phylogenie/treesimulator/tree.py,sha256=DEdzCh4vABq2f095beh3tD3_aee7EyXPDSjcyHKgKLg,6064
27
27
  phylogenie/treesimulator/utils.py,sha256=OxZwVHxN004Jf-kYZ_GfJgIY0beo-0tYq80CuFGQt-M,3416
28
28
  phylogenie/treesimulator/events/__init__.py,sha256=w2tJ0D2WB5AiCbr3CsKN6vdADueiAEMzd_ve0rpa4zg,939
29
29
  phylogenie/treesimulator/events/base.py,sha256=JQKYUZmhB2Q-WQOy2ULGKQiabsMz-JvwMVfDoa3ZKyo,1170
@@ -32,10 +32,10 @@ phylogenie/treesimulator/events/core.py,sha256=bhgQgi5L-oaHsoWJmUOsTTzWxi0POYxVL
32
32
  phylogenie/treesimulator/events/mutations.py,sha256=8Nqa2fg7fwaVNe5XSkGDSwp9pIKQ7XaBQCCj-LYlfzA,3666
33
33
  phylogenie/treesimulator/io/__init__.py,sha256=rfP-zp8SP8baq5_4dPAr10WH0W6KfoMCxdTZDCSXtzE,185
34
34
  phylogenie/treesimulator/io/newick.py,sha256=8Pr_jixByPOaVch18w-rFt62HYy0U97YMu0H-QSwIy0,3449
35
- phylogenie/treesimulator/io/nexus.py,sha256=tPZRPejnG0OEE7YTorwvyZe2vnZHv4KiC9f-N7Kx6qg,1849
36
- phylogenie-3.1.4.dist-info/licenses/LICENSE.txt,sha256=NUrDqElK-eD3I0WqC004CJsy6cs0JgsAoebDv_42-pw,1071
37
- phylogenie-3.1.4.dist-info/METADATA,sha256=en3CvCPhiUbQPZTAUYWpx8nQze6oqS-KQJJa3xGdvbI,5163
38
- phylogenie-3.1.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
39
- phylogenie-3.1.4.dist-info/entry_points.txt,sha256=BBH8LoReHnNFnvq4sROEsVFegfkKJ6c_oHZ7bgK7Jl4,52
40
- phylogenie-3.1.4.dist-info/top_level.txt,sha256=1YGZJhKA9tN9qI0Hcj6Cn_sOoDpba0HQlNcgQTjMD-8,11
41
- phylogenie-3.1.4.dist-info/RECORD,,
35
+ phylogenie/treesimulator/io/nexus.py,sha256=zqT9dzj413z_s0hqp3Cdq5NMO6lv-zuuaJlaqzaqaB8,1847
36
+ phylogenie-3.1.6.dist-info/licenses/LICENSE.txt,sha256=NUrDqElK-eD3I0WqC004CJsy6cs0JgsAoebDv_42-pw,1071
37
+ phylogenie-3.1.6.dist-info/METADATA,sha256=az3fV29mfXpVsyxK5PDssutWcvaLnu_MLfHO4pgyBas,5194
38
+ phylogenie-3.1.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
39
+ phylogenie-3.1.6.dist-info/entry_points.txt,sha256=BBH8LoReHnNFnvq4sROEsVFegfkKJ6c_oHZ7bgK7Jl4,52
40
+ phylogenie-3.1.6.dist-info/top_level.txt,sha256=1YGZJhKA9tN9qI0Hcj6Cn_sOoDpba0HQlNcgQTjMD-8,11
41
+ phylogenie-3.1.6.dist-info/RECORD,,