phylogenie 2.1.4__py3-none-any.whl → 3.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- phylogenie/__init__.py +60 -14
- phylogenie/draw.py +690 -0
- phylogenie/generators/alisim.py +12 -12
- phylogenie/generators/configs.py +26 -4
- phylogenie/generators/dataset.py +3 -3
- phylogenie/generators/factories.py +38 -12
- phylogenie/generators/trees.py +48 -47
- phylogenie/io/__init__.py +3 -0
- phylogenie/io/fasta.py +34 -0
- phylogenie/main.py +27 -10
- phylogenie/mixins.py +33 -0
- phylogenie/skyline/matrix.py +11 -7
- phylogenie/skyline/parameter.py +12 -4
- phylogenie/skyline/vector.py +12 -6
- phylogenie/treesimulator/__init__.py +36 -3
- phylogenie/treesimulator/events/__init__.py +5 -5
- phylogenie/treesimulator/events/base.py +39 -0
- phylogenie/treesimulator/events/contact_tracing.py +38 -23
- phylogenie/treesimulator/events/core.py +21 -12
- phylogenie/treesimulator/events/mutations.py +46 -46
- phylogenie/treesimulator/features.py +49 -0
- phylogenie/treesimulator/gillespie.py +59 -55
- phylogenie/treesimulator/io/__init__.py +4 -0
- phylogenie/treesimulator/io/newick.py +104 -0
- phylogenie/treesimulator/io/nexus.py +50 -0
- phylogenie/treesimulator/model.py +25 -49
- phylogenie/treesimulator/tree.py +196 -0
- phylogenie/treesimulator/utils.py +108 -0
- phylogenie/typings.py +3 -3
- {phylogenie-2.1.4.dist-info → phylogenie-3.1.7.dist-info}/METADATA +13 -15
- phylogenie-3.1.7.dist-info/RECORD +41 -0
- {phylogenie-2.1.4.dist-info → phylogenie-3.1.7.dist-info}/WHEEL +2 -1
- phylogenie-3.1.7.dist-info/entry_points.txt +2 -0
- phylogenie-3.1.7.dist-info/top_level.txt +1 -0
- phylogenie/io.py +0 -107
- phylogenie/tree.py +0 -92
- phylogenie/utils.py +0 -17
- phylogenie-2.1.4.dist-info/RECORD +0 -32
- phylogenie-2.1.4.dist-info/entry_points.txt +0 -3
- {phylogenie-2.1.4.dist-info → phylogenie-3.1.7.dist-info/licenses}/LICENSE.txt +0 -0
phylogenie/draw.py
ADDED
|
@@ -0,0 +1,690 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from datetime import datetime
|
|
3
|
+
from typing import Any, Literal, overload
|
|
4
|
+
|
|
5
|
+
import matplotlib.colors as mcolors
|
|
6
|
+
import matplotlib.dates as mdates
|
|
7
|
+
import matplotlib.patches as mpatches
|
|
8
|
+
import matplotlib.pyplot as plt
|
|
9
|
+
from matplotlib.axes import Axes
|
|
10
|
+
from matplotlib.colors import Colormap
|
|
11
|
+
from mpl_toolkits.axes_grid1.inset_locator import inset_axes # pyright: ignore
|
|
12
|
+
|
|
13
|
+
from phylogenie.treesimulator import Tree, get_node_depth_levels, get_node_depths
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass()
class CalibrationNode:
    """
    A tree node paired with the calendar date assigned to it.

    Two instances of this class define the linear depth-to-date mapping used
    when drawing dated trees.

    Attributes
    ----------
    node : Tree
        The calibrated node.
    date : datetime
        The date attached to ``node``.
    """

    node: Tree  # node whose ``depth`` anchors the mapping
    date: datetime  # calendar date corresponding to that depth
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Color accepted by the drawing functions and forwarded to matplotlib
# (``hlines``/``vlines``/``Patch``): a color string (e.g. a named color or a
# hex code such as those produced by ``mcolors.to_hex``), or an RGB / RGBA
# tuple of floats.
Color = str | tuple[float, float, float] | tuple[float, float, float, float]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def draw_tree(
    tree: Tree,
    ax: Axes | None = None,
    colors: Color | dict[Tree, Color] = "black",
    backward_time: bool = False,
) -> Axes:
    """
    Draw a phylogenetic tree as a rectangular layout with colored branches.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree to draw.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, uses the current Axes.
    colors : Color | dict[Tree, Color], optional
        A single color for every branch, or a mapping from each node to the
        color of the branch leading to it.
    backward_time : bool, optional
        If True, x positions are measured back from the deepest node and the
        x-axis is inverted to represent time going backward.

    Returns
    -------
    Axes
        The Axes with the drawn tree.

    Notes
    -----
    If any non-root node has no branch length, x positions come from
    ``get_node_depth_levels`` (presumably topological levels) rather than
    ``get_node_depths``.
    """
    if ax is None:
        ax = plt.gca()

    # Expand a single color into a per-node lookup table.
    node_colors = colors if isinstance(colors, dict) else {node: colors for node in tree}

    missing_lengths = any(
        node.branch_length is None for node in tree.iter_descendants()
    )
    if missing_lengths:
        xs = get_node_depth_levels(tree)
    else:
        xs = get_node_depths(tree)

    if backward_time:
        deepest = max(xs.values())
        xs = {node: deepest - x for node, x in xs.items()}

    # Leaves take consecutive rows; each internal node sits at the mean row of
    # its children (post-order guarantees children are placed first).
    ys: dict[Tree, float] = {}
    for row, leaf in enumerate(tree.get_leaves()):
        ys[leaf] = row
    for node in tree.postorder_traversal():
        if not node.is_internal():
            continue
        kids = node.children
        ys[node] = sum(ys[kid] for kid in kids) / len(kids)

    # Stem branch above the root, only when the root has a length.
    root_length = tree.branch_length
    if root_length is not None:
        start = xs[tree] + root_length if backward_time else 0
        ax.hlines(y=ys[tree], xmin=start, xmax=xs[tree], color=node_colors[tree])  # pyright: ignore

    # Each child gets a horizontal segment from its parent's x, plus a vertical
    # connector at the parent's x spanning the two rows.
    for parent in tree:
        px, py = xs[parent], ys[parent]
        for kid in parent.children:
            kx, ky = xs[kid], ys[kid]
            branch_color = node_colors[kid]
            ax.hlines(y=ky, xmin=px, xmax=kx, color=branch_color)  # pyright: ignore
            ax.vlines(x=px, ymin=py, ymax=ky, color=branch_color)  # pyright: ignore

    if backward_time:
        ax.invert_xaxis()

    ax.set_yticks([])  # pyright: ignore
    return ax
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _depth_to_date(
    depth: float, calibration_nodes: tuple[CalibrationNode, CalibrationNode]
) -> datetime:
    """
    Map a depth to a calendar date by linear interpolation/extrapolation.

    The line passes through ``(first.node.depth, first.date)`` and
    ``(second.node.depth, second.date)``.

    Parameters
    ----------
    depth : float
        The depth value to convert.
    calibration_nodes : tuple[CalibrationNode, CalibrationNode]
        The two anchors of the depth-to-date line.

    Returns
    -------
    datetime
        The date corresponding to ``depth``.

    Raises
    ------
    ZeroDivisionError
        If both calibration nodes have the same depth.
    """
    first, second = calibration_nodes
    anchor_depth, other_depth = first.node.depth, second.node.depth
    anchor_date = first.date
    # Multiply before dividing, so the timedelta rounding matches the
    # straightforward formula.
    scaled = (depth - anchor_depth) * (second.date - anchor_date)
    return anchor_date + scaled / (other_depth - anchor_depth)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def draw_dated_tree(
    tree: Tree,
    calibration_nodes: tuple[CalibrationNode, CalibrationNode],
    ax: Axes | None = None,
    colors: Color | dict[Tree, Color] = "black",
) -> Axes:
    """
    Draw a phylogenetic tree on a calendar-date x-axis.

    Node depths are converted to dates with ``_depth_to_date`` using the two
    calibration nodes.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree to draw.
    calibration_nodes : tuple[CalibrationNode, CalibrationNode]
        Two calibration nodes defining the mapping from depth to date.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, uses the current Axes.
    colors : Color | dict[Tree, Color], optional
        A single color for every branch, or a mapping from each node to the
        color of the branch leading to it.

    Returns
    -------
    Axes
        The Axes with the drawn dated tree.
    """
    if ax is None:
        ax = plt.gca()

    node_colors = colors if isinstance(colors, dict) else {node: colors for node in tree}

    to_num = mdates.date2num
    # Matplotlib date numbers for every node, converted once per node.
    xs = {
        node: to_num(  # pyright: ignore
            _depth_to_date(depth=depth, calibration_nodes=calibration_nodes)
        )
        for node, depth in get_node_depths(tree).items()
    }

    # Leaves take consecutive rows; internal nodes sit at their children's mean row.
    ys: dict[Tree, float] = {}
    for row, leaf in enumerate(tree.get_leaves()):
        ys[leaf] = row
    for node in tree.postorder_traversal():
        if not node.is_internal():
            continue
        kids = node.children
        ys[node] = sum(ys[kid] for kid in kids) / len(kids)

    # Stem branch from the depth-0 date to the root, if the root has a length.
    if tree.branch_length is not None:
        origin = _depth_to_date(depth=0, calibration_nodes=calibration_nodes)
        ax.hlines(  # pyright: ignore
            y=ys[tree],
            xmin=to_num(origin),  # pyright: ignore
            xmax=xs[tree],
            color=node_colors[tree],
        )

    for parent in tree:
        px, py = xs[parent], ys[parent]
        for kid in parent.children:
            kx, ky = xs[kid], ys[kid]
            branch_color = node_colors[kid]
            ax.hlines(y=ky, xmin=px, xmax=kx, color=branch_color)  # pyright: ignore
            ax.vlines(x=px, ymin=py, ymax=ky, color=branch_color)  # pyright: ignore

    x_axis = ax.xaxis
    x_axis.set_major_locator(mdates.AutoDateLocator())
    x_axis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d"))
    ax.tick_params(axis="x", labelrotation=45)  # pyright: ignore

    ax.set_yticks([])  # pyright: ignore
    return ax
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def _init_colored_tree_categorical(
    tree: Tree,
    color_by: str,
    ax: Axes | None = None,
    default_color: Color = "black",
    colormap: str | Colormap = "tab20",
    show_legend: bool = True,
    labels: dict[Any, str] | None = None,
    legend_kwargs: dict[str, Any] | None = None,
) -> tuple[Axes, dict[Tree, Color]]:
    """
    Initialize colors for drawing a tree based on categorical metadata.

    Categories are ordered deterministically:
    - sorted, when their values are mutually comparable;
    - otherwise, by first appearance while iterating the tree.

    The i-th category gets colormap entry ``i % colormap.N``, so the same
    category receives the same color on every run.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree.
    color_by : str
        The metadata key to color branches by.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, uses the current Axes.
    default_color : Color, optional
        The color to use for nodes without the specified metadata.
    colormap : str | Colormap, optional
        The colormap to use for coloring categories. Defaults to 'tab20'.
        If there are more categories than colormap entries, colors repeat.
    show_legend : bool, optional
        Whether to display a legend for the categories.
    labels : dict[Any, str] | None, optional
        A mapping from category values to legend labels. Categories missing
        from the mapping are labelled with ``str(category)``.
    legend_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to pass to the legend.

    Returns
    -------
    tuple[Axes, dict[Tree, Color]]
        The Axes and a dictionary mapping each node to its assigned color.
    """
    if ax is None:
        ax = plt.gca()

    if isinstance(colormap, str):
        colormap = plt.get_cmap(colormap)

    features = {node: node[color_by] for node in tree if color_by in node.metadata}

    # A ``set`` iterates string values in hash-randomized order, which made the
    # category -> color assignment (and the legend order) change between runs.
    unique = list(dict.fromkeys(features.values()))
    try:
        categories = sorted(unique)
    except TypeError:
        # Unorderable (e.g. mixed-type) categories: keep first-appearance order.
        categories = unique

    # Cycle through the colormap instead of mapping every index >= N onto the
    # colormap's single "over" color.
    feature_colors = {
        f: mcolors.to_hex(colormap(i % colormap.N)) for i, f in enumerate(categories)
    }
    colors = {
        node: feature_colors[features[node]] if node in features else default_color
        for node in tree
    }

    if show_legend:
        legend_handles = [
            mpatches.Patch(
                color=feature_colors[f],
                label=str(f) if labels is None else labels.get(f, str(f)),
            )
            for f in feature_colors
        ]
        # Nodes lacking the key are drawn in default_color; give them an entry too.
        if any(color_by not in node.metadata for node in tree):
            legend_handles.append(mpatches.Patch(color=default_color, label="NA"))
        if legend_kwargs is None:
            legend_kwargs = {}
        ax.legend(handles=legend_handles, **legend_kwargs)  # pyright: ignore

    return ax, colors
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def draw_colored_tree_categorical(
    tree: Tree,
    color_by: str,
    ax: Axes | None = None,
    backward_time: bool = False,
    default_color: Color = "black",
    colormap: str | Colormap = "tab20",
    show_legend: bool = True,
    labels: dict[Any, str] | None = None,
    legend_kwargs: dict[str, Any] | None = None,
) -> Axes:
    """
    Draw a phylogenetic tree with branches colored based on categorical metadata.

    Colors (and the optional legend) are set up by
    ``_init_colored_tree_categorical``; the tree is then drawn with ``draw_tree``.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree to draw.
    color_by : str
        The metadata key to color branches by.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, uses the current Axes.
    backward_time : bool, optional
        If True, the x-axis is inverted to represent time going backward.
    default_color : Color, optional
        The color to use for nodes without the specified metadata.
    colormap : str | Colormap, optional
        The colormap to use for coloring categories. Defaults to 'tab20'.
    show_legend : bool, optional
        Whether to display a legend for the categories.
    labels : dict[Any, str] | None, optional
        A mapping from category values to labels for the legend.
    legend_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to pass to the legend.

    Returns
    -------
    Axes
        The Axes with the drawn colored tree.
    """
    ax, colors = _init_colored_tree_categorical(
        tree=tree,
        color_by=color_by,
        ax=ax,
        default_color=default_color,
        colormap=colormap,
        show_legend=show_legend,
        labels=labels,
        legend_kwargs=legend_kwargs,
    )
    return draw_tree(tree=tree, ax=ax, colors=colors, backward_time=backward_time)
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def draw_colored_dated_tree_categorical(
    tree: Tree,
    calibration_nodes: tuple[CalibrationNode, CalibrationNode],
    color_by: str,
    ax: Axes | None = None,
    default_color: Color = "black",
    colormap: str | Colormap = "tab20",
    show_legend: bool = True,
    labels: dict[Any, str] | None = None,
    legend_kwargs: dict[str, Any] | None = None,
) -> Axes:
    """
    Draw a dated phylogenetic tree whose branches are colored by a categorical
    metadata key.

    Colors (and the optional legend) are prepared by
    ``_init_colored_tree_categorical``; drawing is delegated to
    ``draw_dated_tree``.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree to draw.
    calibration_nodes : tuple[CalibrationNode, CalibrationNode]
        Two calibration nodes defining the mapping from depth to date.
    color_by : str
        The metadata key to color branches by.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, uses the current Axes.
    default_color : Color, optional
        The color to use for nodes without the specified metadata.
    colormap : str | Colormap, optional
        The colormap to use for coloring categories. Defaults to 'tab20'.
    show_legend : bool, optional
        Whether to display a legend for the categories.
    labels : dict[Any, str] | None, optional
        A mapping from category values to labels for the legend.
    legend_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to pass to the legend.

    Returns
    -------
    Axes
        The Axes with the drawn colored dated tree.
    """
    target_ax, node_colors = _init_colored_tree_categorical(
        tree,
        color_by,
        ax=ax,
        default_color=default_color,
        colormap=colormap,
        show_legend=show_legend,
        labels=labels,
        legend_kwargs=legend_kwargs,
    )
    drawn = draw_dated_tree(
        tree,
        calibration_nodes,
        ax=target_ax,
        colors=node_colors,
    )
    return drawn
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
@overload
def _init_colored_tree_continuous(
    tree: Tree,
    color_by: str,
    ax: Axes | None = ...,
    default_color: Color = ...,
    colormap: str | Colormap = ...,
    vmin: float | None = ...,
    vmax: float | None = ...,
    *,
    show_hist: Literal[False],
    hist_kwargs: dict[str, Any] | None = ...,
    hist_axes_kwargs: dict[str, Any] | None = ...,
) -> tuple[Axes, dict[Tree, Color]]: ...
@overload
def _init_colored_tree_continuous(
    tree: Tree,
    color_by: str,
    ax: Axes | None = ...,
    default_color: Color = ...,
    colormap: str | Colormap = ...,
    vmin: float | None = ...,
    vmax: float | None = ...,
    *,
    # Default matches the implementation (True), so calls that omit
    # ``show_hist`` resolve to this overload.
    show_hist: Literal[True] = ...,
    hist_kwargs: dict[str, Any] | None = ...,
    hist_axes_kwargs: dict[str, Any] | None = ...,
) -> tuple[Axes, dict[Tree, Color], Axes]: ...
def _init_colored_tree_continuous(
    tree: Tree,
    color_by: str,
    ax: Axes | None = None,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    *,
    show_hist: bool = True,
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> tuple[Axes, dict[Tree, Color]] | tuple[Axes, dict[Tree, Color], Axes]:
    """
    Initialize colors for drawing a tree based on continuous metadata.

    If ``show_hist`` is True, an inset histogram of the values is drawn with
    bars colored by the colormap. Otherwise a colorbar is attached to the
    figure.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree.
    color_by : str
        The metadata key to color branches by.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, uses the current Axes.
    default_color : Color, optional
        The color to use for nodes without the specified metadata.
    colormap : str | Colormap, optional
        The colormap to use for coloring continuous values. Defaults to 'viridis'.
    vmin : float | None, optional
        The minimum value for normalization. If None, uses the minimum of the data.
    vmax : float | None, optional
        The maximum value for normalization. If None, uses the maximum of the data.
    show_hist : bool, optional
        Whether to display a histogram of the continuous values. Defaults to True.
    hist_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to pass to the histogram.
    hist_axes_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to define the histogram Axes
        (merged over the default ``{"width": "25%", "height": "25%"}``).

    Returns
    -------
    tuple[Axes, dict[Tree, Color]] | tuple[Axes, dict[Tree, Color], Axes]
        The Axes, a dictionary mapping each node to its assigned color,
        and optionally the histogram Axes if `show_hist` is True.

    Raises
    ------
    ValueError
        If no node carries ``color_by`` and ``vmin`` or ``vmax`` was not given,
        because the color range cannot then be inferred from the data.
    """
    if ax is None:
        ax = plt.gca()

    if isinstance(colormap, str):
        colormap = plt.get_cmap(colormap)

    features = {node: node[color_by] for node in tree if color_by in node.metadata}
    values = list(features.values())
    # Without this check, min()/max() on an empty list fail with an opaque message.
    if not values and (vmin is None or vmax is None):
        raise ValueError(
            f"No node has metadata {color_by!r}, so the color range cannot be "
            "inferred from the data; pass both `vmin` and `vmax` explicitly."
        )
    vmin = min(values) if vmin is None else vmin
    vmax = max(values) if vmax is None else vmax
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
    colors = {
        node: colormap(norm(float(features[node])))
        if node in features
        else default_color
        for node in tree
    }

    if show_hist:
        default_hist_axes_kwargs = {"width": "25%", "height": "25%"}
        if hist_axes_kwargs is not None:
            default_hist_axes_kwargs.update(hist_axes_kwargs)
        hist_ax = inset_axes(ax, **default_hist_axes_kwargs)  # pyright: ignore

        hist_kwargs = {} if hist_kwargs is None else hist_kwargs
        _, bins, patches = hist_ax.hist(values, **hist_kwargs)  # pyright: ignore

        # Color each bar by the colormap value at its bin midpoint. This branch
        # adds no colorbar, so the histogram acts as the color key.
        for patch, b0, b1 in zip(patches, bins[:-1], bins[1:]):  # pyright: ignore
            midpoint = (b0 + b1) / 2  # pyright: ignore
            patch.set_facecolor(colormap(norm(midpoint)))  # pyright: ignore
        return ax, colors, hist_ax  # pyright: ignore

    sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)
    ax.get_figure().colorbar(sm, ax=ax)  # pyright: ignore
    return ax, colors
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
@overload
def draw_colored_tree_continuous(
    tree: Tree,
    color_by: str,
    ax: Axes | None = None,
    backward_time: bool = False,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    *,
    show_hist: Literal[False],
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> Axes: ...
@overload
def draw_colored_tree_continuous(
    tree: Tree,
    color_by: str,
    ax: Axes | None = None,
    backward_time: bool = False,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    *,
    # Default matches the implementation, so calls omitting ``show_hist``
    # resolve to this overload.
    show_hist: Literal[True] = True,
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> tuple[Axes, Axes]: ...
def draw_colored_tree_continuous(
    tree: Tree,
    color_by: str,
    ax: Axes | None = None,
    backward_time: bool = False,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    show_hist: bool = True,
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> Axes | tuple[Axes, Axes]:
    """
    Draw a phylogenetic tree with branches colored based on continuous metadata.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree to draw.
    color_by : str
        The metadata key to color branches by.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, uses the current Axes.
    backward_time : bool, optional
        If True, the x-axis is inverted to represent time going backward.
    default_color : Color, optional
        The color to use for nodes without the specified metadata.
    colormap : str | Colormap, optional
        The colormap to use for coloring continuous values. Defaults to 'viridis'.
    vmin : float | None, optional
        The minimum value for normalization. If None, uses the minimum of the data.
    vmax : float | None, optional
        The maximum value for normalization. If None, uses the maximum of the data.
    show_hist : bool, optional
        Whether to display a histogram of the continuous values. Defaults to True.
        If False, a colorbar is shown instead.
    hist_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to pass to the histogram.
    hist_axes_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to define the histogram Axes.

    Returns
    -------
    Axes | tuple[Axes, Axes]
        The Axes with the drawn colored tree,
        and optionally the histogram Axes if `show_hist` is True.
    """
    if show_hist:
        # Pass a literal so the helper call resolves to its 3-tuple overload.
        ax, colors, hist_ax = _init_colored_tree_continuous(
            tree=tree,
            color_by=color_by,
            ax=ax,
            default_color=default_color,
            colormap=colormap,
            vmin=vmin,
            vmax=vmax,
            show_hist=True,
            hist_kwargs=hist_kwargs,
            hist_axes_kwargs=hist_axes_kwargs,
        )
        tree_ax = draw_tree(
            tree=tree, ax=ax, colors=colors, backward_time=backward_time
        )
        return tree_ax, hist_ax

    ax, colors = _init_colored_tree_continuous(
        tree=tree,
        color_by=color_by,
        ax=ax,
        default_color=default_color,
        colormap=colormap,
        vmin=vmin,
        vmax=vmax,
        show_hist=False,
        hist_kwargs=hist_kwargs,
        hist_axes_kwargs=hist_axes_kwargs,
    )
    return draw_tree(tree=tree, ax=ax, colors=colors, backward_time=backward_time)
|
|
574
|
+
|
|
575
|
+
|
|
576
|
+
@overload
def draw_colored_dated_tree_continuous(
    tree: Tree,
    calibration_nodes: tuple[CalibrationNode, CalibrationNode],
    color_by: str,
    ax: Axes | None = None,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    *,
    show_hist: Literal[False],
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> Axes: ...
@overload
def draw_colored_dated_tree_continuous(
    tree: Tree,
    calibration_nodes: tuple[CalibrationNode, CalibrationNode],
    color_by: str,
    ax: Axes | None = None,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    *,
    # Default matches the implementation, so calls omitting ``show_hist``
    # resolve to this overload.
    show_hist: Literal[True] = True,
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> tuple[Axes, Axes]: ...
def draw_colored_dated_tree_continuous(
    tree: Tree,
    calibration_nodes: tuple[CalibrationNode, CalibrationNode],
    color_by: str,
    ax: Axes | None = None,
    default_color: Color = "black",
    colormap: str | Colormap = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    show_hist: bool = True,
    hist_kwargs: dict[str, Any] | None = None,
    hist_axes_kwargs: dict[str, Any] | None = None,
) -> Axes | tuple[Axes, Axes]:
    """
    Draw a dated phylogenetic tree with branches colored based on continuous metadata.

    Parameters
    ----------
    tree : Tree
        The phylogenetic tree to draw.
    calibration_nodes : tuple[CalibrationNode, CalibrationNode]
        Two calibration nodes defining the mapping from depth to date.
    color_by : str
        The metadata key to color branches by.
    ax : Axes | None, optional
        The matplotlib Axes to draw on. If None, uses the current Axes.
    default_color : Color, optional
        The color to use for nodes without the specified metadata.
    colormap : str | Colormap, optional
        The colormap to use for coloring continuous values. Defaults to 'viridis'.
    vmin : float | None, optional
        The minimum value for normalization. If None, uses the minimum of the data.
    vmax : float | None, optional
        The maximum value for normalization. If None, uses the maximum of the data.
    show_hist : bool, optional
        Whether to display a histogram of the continuous values. Defaults to True.
        If False, a colorbar is shown instead.
    hist_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to pass to the histogram.
    hist_axes_kwargs : dict[str, Any] | None, optional
        Additional keyword arguments to define the histogram Axes.

    Returns
    -------
    Axes | tuple[Axes, Axes]
        The Axes with the drawn colored dated tree,
        and optionally the histogram Axes if `show_hist` is True.
    """
    if show_hist:
        # Pass a literal so the helper call resolves to its 3-tuple overload.
        ax, colors, hist_ax = _init_colored_tree_continuous(
            tree=tree,
            color_by=color_by,
            ax=ax,
            default_color=default_color,
            colormap=colormap,
            vmin=vmin,
            vmax=vmax,
            show_hist=True,
            hist_kwargs=hist_kwargs,
            hist_axes_kwargs=hist_axes_kwargs,
        )
        tree_ax = draw_dated_tree(
            tree=tree,
            calibration_nodes=calibration_nodes,
            ax=ax,
            colors=colors,
        )
        return tree_ax, hist_ax

    ax, colors = _init_colored_tree_continuous(
        tree=tree,
        color_by=color_by,
        ax=ax,
        default_color=default_color,
        colormap=colormap,
        vmin=vmin,
        vmax=vmax,
        show_hist=False,
        hist_kwargs=hist_kwargs,
        hist_axes_kwargs=hist_axes_kwargs,
    )
    return draw_dated_tree(
        tree=tree,
        calibration_nodes=calibration_nodes,
        ax=ax,
        colors=colors,
    )
|