tskit 1.0.1__cp314-cp314-macosx_10_15_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- _tskit.cpython-314-darwin.so +0 -0
- tskit/__init__.py +92 -0
- tskit/__main__.py +4 -0
- tskit/_version.py +4 -0
- tskit/cli.py +273 -0
- tskit/combinatorics.py +1522 -0
- tskit/drawing.py +2809 -0
- tskit/exceptions.py +70 -0
- tskit/genotypes.py +410 -0
- tskit/intervals.py +601 -0
- tskit/jit/__init__.py +0 -0
- tskit/jit/numba.py +674 -0
- tskit/metadata.py +1147 -0
- tskit/provenance.py +150 -0
- tskit/provenance.schema.json +72 -0
- tskit/stats.py +165 -0
- tskit/tables.py +4858 -0
- tskit/text_formats.py +456 -0
- tskit/trees.py +11457 -0
- tskit/util.py +901 -0
- tskit/vcf.py +219 -0
- tskit-1.0.1.dist-info/METADATA +105 -0
- tskit-1.0.1.dist-info/RECORD +27 -0
- tskit-1.0.1.dist-info/WHEEL +5 -0
- tskit-1.0.1.dist-info/entry_points.txt +2 -0
- tskit-1.0.1.dist-info/licenses/LICENSE +21 -0
- tskit-1.0.1.dist-info/top_level.txt +2 -0
tskit/drawing.py
ADDED
|
@@ -0,0 +1,2809 @@
|
|
|
1
|
+
# MIT License
|
|
2
|
+
#
|
|
3
|
+
# Copyright (c) 2018-2025 Tskit Developers
|
|
4
|
+
# Copyright (c) 2015-2017 University of Oxford
|
|
5
|
+
#
|
|
6
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
7
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
8
|
+
# in the Software without restriction, including without limitation the rights
|
|
9
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
10
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
11
|
+
# furnished to do so, subject to the following conditions:
|
|
12
|
+
#
|
|
13
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
14
|
+
# copies or substantial portions of the Software.
|
|
15
|
+
#
|
|
16
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
17
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
18
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
19
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
20
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
21
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
22
|
+
# SOFTWARE.
|
|
23
|
+
"""
|
|
24
|
+
Module responsible for visualisations.
|
|
25
|
+
"""
|
|
26
|
+
import collections
|
|
27
|
+
import itertools
|
|
28
|
+
import logging
|
|
29
|
+
import math
|
|
30
|
+
import numbers
|
|
31
|
+
import operator
|
|
32
|
+
import warnings
|
|
33
|
+
import xml.dom.minidom
|
|
34
|
+
from collections.abc import Mapping
|
|
35
|
+
from dataclasses import dataclass
|
|
36
|
+
|
|
37
|
+
import numpy as np
|
|
38
|
+
|
|
39
|
+
import tskit
|
|
40
|
+
import tskit.util as util
|
|
41
|
+
from _tskit import NODE_IS_SAMPLE
|
|
42
|
+
from _tskit import NULL
|
|
43
|
+
|
|
44
|
+
# Orientations accepted when drawing trees (see check_orientation)
LEFT = "left"
RIGHT = "right"
TOP = "top"
BOTTOM = "bottom"

# Bit flags recording whether (and how) to plot each tree in a tree sequence.
# They are distinct powers of two so that they can be combined with ``|``.
OMIT = 1 << 0
LEFT_CLIP = 1 << 1
RIGHT_CLIP = 1 << 2
OMIT_MIDDLE = 1 << 3
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# Minimal SVG generation module to replace svgwrite for tskit visualization.
|
|
57
|
+
# This implementation provides only the functionality needed for the visualization
|
|
58
|
+
# code while maintaining the same API as svgwrite.
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class Element:
    """
    Minimal XML element used to build SVG output.

    An element has a ``tag``, a dict of attributes (``attrs``) and an ordered
    list of ``children``. Each child is either another ``Element`` or a value
    that is written out verbatim via ``str()``.

    NOTE: attribute values and string children are not XML-escaped. Text
    containing markup characters (``<``, ``&``, ``"``) must already be escaped,
    or wrapped in CDATA as the ``<style>`` helper does.
    """

    def __init__(self, tag, **kwargs):
        self.tag = tag
        self.attrs = {}
        self.children = []

        # Process kwargs in alphabetical order so attribute output is stable
        for key in sorted(kwargs.keys()):
            value = kwargs[key]
            # A trailing underscore (e.g. ``class_``) is stripped so Python
            # keywords can be used as attribute names; remaining underscores
            # become hyphens (e.g. ``font_size`` -> ``font-size``).
            if key.endswith("_"):
                key = key[:-1]
            key = key.replace("_", "-")
            self.attrs[key] = value

    def __getitem__(self, key):
        """Return attribute ``key``, or an empty string if it is not set."""
        return self.attrs.get(key, "")

    def __setitem__(self, key, value):
        """Set attribute ``key`` exactly as given (no underscore mangling)."""
        self.attrs[key] = value

    def add(self, child):
        """Append ``child`` (an Element or a raw string) and return it."""
        self.children.append(child)
        return child

    def set_desc(self, **kwargs):
        """
        Append a ``<title>`` child if a ``title`` keyword is given (other
        keywords are ignored). Returns ``self`` so calls can be chained.
        """
        if "title" in kwargs:
            title_elem = Element("title")
            title_elem.children.append(kwargs["title"])
            self.children.append(title_elem)
        return self

    def _attr_str(self):
        """Render the attributes as space-separated ``key="value"`` pairs."""
        result = []
        for key, value in self.attrs.items():
            if isinstance(value, (list, tuple)):
                # Handle points lists (for polygon/polyline)
                if key == "points":
                    points_str = " ".join(f"{x},{y}" for x, y in value)
                    result.append(f'{key}="{points_str}"')
                else:
                    result.append(f'{key}="{" ".join(map(str, value))}"')
            else:
                result.append(f'{key}="{value}"')
        return " ".join(result)

    def tostring(self):
        """
        Serialise this element and all of its descendants to an XML string.

        The traversal is iterative, so deep trees do not hit the recursion
        limit. Children are emitted in document order.
        """
        # Each stack entry is (item, is_closing_tag). Raw string children are
        # pushed onto the stack too, rather than written immediately, so that
        # they keep their position relative to sibling elements.
        stack = [(self, False)]
        result = []

        while stack:
            item, is_closing_tag = stack.pop()
            if is_closing_tag:
                result.append(f"</{item.tag}>")
                continue
            if not isinstance(item, Element):
                result.append(str(item))
                continue
            attr_str = item._attr_str()
            start = f"<{item.tag}"
            if attr_str:
                start += f" {attr_str}"
            if not item.children:
                result.append(f"{start}/>")
            else:
                result.append(f"{start}>")
                stack.append((item, True))
                # Push in reverse so the first child is popped first
                for child in reversed(item.children):
                    stack.append((child, False))

        return "".join(result)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class Drawing:
    """
    Small stand-in for the parts of ``svgwrite.Drawing`` used in this module.

    The drawing owns a root ``<svg>`` Element (``self.root``). Its first child
    is an empty-string preamble and its second is a ``<defs>`` element
    (``self.defs``). The factory methods (``g``, ``rect``, ``circle``, ...)
    create new elements without attaching them; attach them with ``add`` or
    with ``Element.add``.
    """

    def __init__(self, size=None, **kwargs):
        svg_attrs = {
            "version": "1.1",
            "xmlns": "http://www.w3.org/2000/svg",
            "xmlns:ev": "http://www.w3.org/2001/xml-events",
            "xmlns:xlink": "http://www.w3.org/1999/xlink",
            "baseProfile": "full",
        }
        # Caller-supplied attributes override the defaults above
        svg_attrs.update(kwargs)
        if size is not None:
            svg_attrs["width"] = size[0]
            svg_attrs["height"] = size[1]

        self.root = Element("svg", **svg_attrs)
        self.root.add("")  # First root elem is a blank preamble
        self.defs = self.root.add(Element("defs"))

    def add(self, element):
        """Attach ``element`` to the root ``<svg>`` element and return it."""
        return self.root.add(element)

    def g(self, **kwargs):
        """Create a ``<g>`` group element."""
        return Element("g", **kwargs)

    def rect(self, insert=None, size=None, **kwargs):
        """Create a ``<rect>``; ``insert`` is (x, y), ``size`` is (width, height)."""
        if insert:
            kwargs["x"], kwargs["y"] = insert[0], insert[1]
        if size:
            kwargs["width"], kwargs["height"] = size[0], size[1]
        return Element("rect", **kwargs)

    def circle(self, center=None, r=None, **kwargs):
        """Create a ``<circle>`` centred on ``center`` (cx, cy) with radius ``r``."""
        if center:
            kwargs["cx"], kwargs["cy"] = center[0], center[1]
        if r:
            kwargs["r"] = r
        return Element("circle", **kwargs)

    def line(self, start=None, end=None, **kwargs):
        """Create a ``<line>`` from ``start`` to ``end``; a missing end is (0, 0)."""
        kwargs["x1"], kwargs["y1"] = (start[0], start[1]) if start else (0, 0)
        kwargs["x2"], kwargs["y2"] = (end[0], end[1]) if end else (0, 0)
        return Element("line", **kwargs)

    def polyline(self, points=None, **kwargs):
        """Create a ``<polyline>`` from a sequence of (x, y) points."""
        if points:
            kwargs["points"] = points
        return Element("polyline", **kwargs)

    def polygon(self, points=None, **kwargs):
        """Create a ``<polygon>`` from a sequence of (x, y) points."""
        if points:
            kwargs["points"] = points
        return Element("polygon", **kwargs)

    def path(self, d=None, **kwargs):
        """
        Create a ``<path>``. ``d`` may be a ready-made path string, or a list of
        command tuples such as ``("M", (x, y))``, whose tuple parameters are
        flattened. List entries that are not tuples of length >= 2 are skipped.
        """
        if isinstance(d, list):
            pieces = []
            for command in d:
                if not (isinstance(command, tuple) and len(command) >= 2):
                    continue
                params = []
                for param in command[1:]:
                    if isinstance(param, tuple):
                        params.extend(str(p) for p in param)
                    else:
                        params.append(str(param))
                pieces.append(f"{command[0]} {' '.join(params)} ")
            kwargs["d"] = "".join(pieces).strip()
        elif d:
            kwargs["d"] = d
        return Element("path", **kwargs)

    def text(self, text=None, **kwargs):
        """Create a ``<text>`` element containing ``text`` (if non-empty)."""
        elem = Element("text", **kwargs)
        if text:
            elem.add(text)
        return elem

    def style(self, content):
        """Create a ``<style type="text/css">`` element holding ``content``."""
        elem = Element("style", type="text/css")
        if content:
            # Use CDATA to avoid having to escape special characters in CSS
            elem.add(f"<![CDATA[{content}]]>")
        return elem

    def tostring(self, pretty=False):
        """Serialise the drawing; ``pretty`` re-indents it via xml.dom.minidom."""
        markup = self.root.tostring()
        if pretty:
            markup = xml.dom.minidom.parseString(markup).toprettyxml()
        return markup

    def saveas(self, path, pretty=False):
        """Write the serialised drawing to ``path`` as UTF-8 text."""
        with open(path, "w", encoding="utf-8") as outfile:
            outfile.write(self.tostring(pretty=pretty))
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
@dataclass
class Offsets:
    """
    Index offsets recorded when ``x_lim`` is set and the displayed tree
    sequence has therefore been cut down with ``keep_intervals``.
    """

    tree: int = 0  # offset for tree indexes
    site: int = 0  # offset for site ids
    mutation: int = 0  # offset for mutation ids
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
@dataclass(frozen=True)
class Timescaling:
    """
    Map times onto a plot coordinate.

    ``min_time`` maps to ``plot_min``, and ``max_time`` maps to
    ``plot_min - plot_range``, so larger times sit further up the plot.
    After construction, ``transform`` is bound to either ``log_transform``
    or ``linear_transform``, depending on ``use_log_transform``.
    """

    max_time: float
    min_time: float
    plot_min: float
    plot_range: float
    use_log_transform: bool

    def __post_init__(self):
        if self.plot_range < 0:
            raise ValueError("Image size too small to allow space to plot tree")
        if self.use_log_transform and self.min_time < 0:
            raise ValueError("Cannot use a log scale if there are negative times")
        chosen = self.log_transform if self.use_log_transform else self.linear_transform
        # The dataclass is frozen, so bypass its __setattr__ to store the choice
        object.__setattr__(self, "transform", chosen)

    def log_transform(self, y):
        "Log-scale transform; 1 is added to all times when min_time is 0"
        offset = 1 if self.min_time == 0 else 0
        top = np.log(self.max_time + offset)
        bottom = np.log(self.min_time + offset)
        scale = self.plot_range / (top - bottom)
        return self.plot_min - (np.log(y + offset) - bottom) * scale

    def linear_transform(self, y):
        "Linear transform of times between min_time and max_time"
        scale = self.plot_range / (self.max_time - self.min_time)
        return self.plot_min - (y - self.min_time) * scale
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
class SVGString(str):
    """A ``str`` subclass holding SVG markup, so Jupyter can render it inline."""

    def _repr_svg_(self):
        """
        Jupyter rich-display hook: the string is already SVG, so return it as is.
        """
        return self
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def check_orientation(orientation):
    """
    Validate a tree orientation, case-insensitively.

    :param orientation: One of ``"left"``, ``"right"``, ``"top"`` or
        ``"bottom"`` (in any case), or None for the default, ``"top"``.
    :return: The orientation in lower case.
    :raises ValueError: If the orientation is not recognised.
    """
    if orientation is None:
        orientation = TOP
    else:
        orientation = orientation.lower()
    orientations = [LEFT, RIGHT, TOP, BOTTOM]
    if orientation not in orientations:
        raise ValueError(f"Unknown orientation: choose from {orientations}")
    return orientation
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def check_max_time(max_time, allow_numeric=True):
    """
    Validate ``max_time``, where None means ``"tree"``.

    ``"tree"`` and ``"ts"`` are always accepted. Real numbers are accepted
    only when ``allow_numeric`` is true. Anything else raises ValueError.
    """
    if max_time is None:
        max_time = "tree"
    if max_time in ["tree", "ts"]:
        return max_time
    if not allow_numeric:
        raise ValueError("max_time must be 'tree' or 'ts'")
    if not isinstance(max_time, numbers.Real):
        raise ValueError("max_time must be a numeric value or one of 'tree' or 'ts'")
    return max_time
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def check_min_time(min_time, allow_numeric=True):
    """
    Validate ``min_time``, where None means ``"tree"``.

    ``"tree"`` and ``"ts"`` are always accepted. Real numbers are accepted
    only when ``allow_numeric`` is true. Anything else raises ValueError.
    """
    if min_time is None:
        min_time = "tree"
    if min_time in ("tree", "ts"):
        return min_time
    if not allow_numeric:
        raise ValueError("min_time must be 'tree' or 'ts'")
    if isinstance(min_time, numbers.Real):
        return min_time
    raise ValueError("min_time must be a numeric value or one of 'tree' or 'ts'")
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
def check_time_scale(time_scale):
    """Validate ``time_scale`` (None means ``"time"``) and return it."""
    time_scale = "time" if time_scale is None else time_scale
    if time_scale in ["time", "log_time", "rank"]:
        return time_scale
    raise ValueError("time_scale must be 'time', 'log_time' or 'rank'")
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
def check_format(format):  # noqa A002
    """
    Validate a drawing format name case-insensitively and return it in lower
    case. None means ``"SVG"``. Supported formats are svg, ascii and unicode.
    """
    if format is None:
        format = "SVG"  # noqa A001
    supported_formats = ["svg", "ascii", "unicode"]
    fmt = format.lower()
    if fmt in supported_formats:
        return fmt
    raise ValueError(
        f"Unknown format '{format}'. Supported formats are {supported_formats}"
    )
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def check_order(order):
    """
    Check that the drawing order is valid and return the matching tree
    traversal order.

    ``"minlex"`` (the default when None) maps to ``"minlex_postorder"`` and
    ``"tree"`` maps to ``"postorder"``. Already-translated traversal orders
    are returned unchanged.
    """
    if order is None:
        order = "minlex"
    traversal_orders = dict(minlex="minlex_postorder", tree="postorder")
    # Silently accept a tree traversal order as a valid order, so we can
    # call this check twice if necessary
    if order in traversal_orders.values():
        return order
    if order in traversal_orders:
        return traversal_orders[order]
    raise ValueError(
        f"Unknown display order '{order}'. "
        f"Supported orders are {list(traversal_orders)}"
    )
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
def check_x_scale(x_scale):
    """
    Check that ``x_scale`` is valid, returning ``"physical"`` if it is None.

    :raises ValueError: If ``x_scale`` is not ``"physical"`` or ``"treewise"``.
    """
    if x_scale is None:
        x_scale = "physical"
    x_scales = ["physical", "treewise"]
    if x_scale not in x_scales:
        raise ValueError(
            f"Unknown display x_scale '{x_scale}'. Supported x_scales are {x_scales}"
        )
    return x_scale
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
def check_x_lim(x_lim, max_x):
    """
    Validate genomic plotting limits.

    Returns ``(None, None)`` when ``x_lim`` is None; otherwise returns ``x_lim``
    unchanged. Either entry may be None, meaning no limit on that side.

    :raises ValueError: If ``x_lim`` does not have length 2, the lower limit is
        negative, the upper limit exceeds ``max_x``, or the lower limit is not
        less than the upper limit.
    :raises TypeError: If the limits cannot be indexed or compared numerically.
    """
    if x_lim is None:
        x_lim = (None, None)
    if len(x_lim) != 2:
        raise ValueError("The x_lim parameter must be a list of length 2, or None")
    try:
        lower = x_lim[0]
        if lower is not None and lower < 0:
            raise ValueError("x_lim[0] cannot be negative")
        upper = x_lim[1]
        if upper is not None and upper > max_x:
            raise ValueError("x_lim[1] cannot be greater than the sequence length")
        both_given = lower is not None and upper is not None
        if both_given and lower >= upper:
            raise ValueError("x_lim[0] must be less than x_lim[1]")
    except TypeError:
        raise TypeError("x_lim parameters must be numeric")
    return x_lim
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
def check_y_axis(y_axis):
    """
    Normalise the ``y_axis`` option.

    None means no axis (False) and True means ``"left"``. Otherwise the value
    must be ``"left"``, ``"right"`` or False, or a ValueError is raised.
    """
    if y_axis is None:
        y_axis = False
    elif y_axis is True:
        y_axis = "left"
    valid = ["left", "right", False]
    if y_axis in valid:
        return y_axis
    raise ValueError(f"Unknown y_axis specification: '{y_axis}'.")
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
def create_tick_labels(tick_values, decimal_places=2):
    """
    Format numeric tick values as strings.

    Values are shown with ``decimal_places`` decimals, or with none if every
    value is a whole number. If rounding the values raises TypeError (e.g.
    non-numeric values), ``tick_values`` is returned unchanged.
    """
    try:
        all_whole = np.all(np.round(tick_values) == tick_values)
    except TypeError:
        return tick_values
    precision = 0 if all_whole else decimal_places
    spec = f".{precision}f"
    return [format(value, spec) for value in tick_values]
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
def clip_ts(ts, x_min, x_max, max_num_trees=None):
    """
    Culls the edges of the tree sequence outside the limits of x_min and x_max if
    necessary, and flags internal trees for omission if there are more than
    max_num_trees in the tree sequence

    Returns a tuple ``(ts, tree_status, offsets)``:

    - ``ts`` is the new tree sequence, using the same genomic scale.
    - ``tree_status`` is a uint8 array of bit flags (OMIT, LEFT_CLIP,
      RIGHT_CLIP, OMIT_MIDDLE) for each tree in the new ``ts``, specifying
      which trees to actually plot. It also records whether a plotted tree was
      clipped: clipping can give the leftmost and rightmost trees in the new
      ts reduced spans, and these should be displayed by omitting the
      appropriate breakpoint.
    - ``offsets`` is an ``Offsets`` instance, which is non-zero only when
      clipping actually took place.

    If x_min is None, we take it to be 0 if the first tree has edges or sites, or
    ``min(edges.left)`` if the first tree represents an empty region.
    Similarly, if x_max is None we take it to be ``ts.sequence_length`` if the last
    tree has edges or sites, or ``max(edges.right)`` if the last tree represents
    an empty region.

    To plot the full ts, including empty flanking regions, specify x_lim of
    [0, seq_len].

    :raises ValueError: If the tree sequence (or the clipped region) is empty,
        or if ``max_num_trees`` is less than 2.
    """
    edges = ts.tables.edges
    sites = ts.tables.sites
    offsets = Offsets()
    if x_min is None:
        if ts.num_edges == 0:
            if ts.num_sites == 0:
                raise ValueError(
                    "To plot an empty tree sequence, specify x_lim=[0, sequence_length]"
                )
            x_min = 0
        else:
            x_min = np.min(edges.left)
            if ts.num_sites > 0 and np.min(sites.position) < x_min:
                x_min = 0  # First region has no edges, but does have sites => keep
    if x_max is None:
        if ts.num_edges == 0:
            if ts.num_sites == 0:
                raise ValueError(
                    "To plot an empty tree sequence, specify x_lim=[0, sequence_length]"
                )
            x_max = ts.sequence_length
        else:
            x_max = np.max(edges.right)
            if ts.num_sites > 0 and np.max(sites.position) > x_max:
                x_max = ts.sequence_length  # Last region has sites but no edges => keep

    if max_num_trees is None:
        max_num_trees = np.inf

    if max_num_trees < 2:
        raise ValueError("Must show at least 2 trees when clipping a tree sequence")

    if (x_min > 0) or (x_max < ts.sequence_length):
        old_breaks = ts.breakpoints(as_array=True)
        # NOTE(review): the "- 2" appears to compensate for keep_intervals
        # leaving an empty leading tree [0, x_min) when x_min > 0; confirm how
        # callers apply offsets.tree when x_min == 0 (it is then -1).
        offsets.tree = np.searchsorted(old_breaks, x_min, "right") - 2
        # Number of sites strictly to the left of x_min
        offsets.site = np.searchsorted(sites.position, x_min)
        offsets.mutation = np.searchsorted(ts.tables.mutations.site, offsets.site)
        ts = ts.keep_intervals([[x_min, x_max]], simplify=False)
        if ts.num_edges == 0:
            raise ValueError(
                f"Can't limit plotting from {x_min} to {x_max} as whole region is empty"
            )
        edges = ts.tables.edges
        sites = ts.tables.sites
        trees_start = np.min(edges.left)
        trees_end = np.max(edges.right)
        tree_status = np.zeros(ts.num_trees, dtype=np.uint8)
        # Are the leftmost/rightmost regions completely empty - if so, don't plot them
        if 0 < x_min <= trees_start and (
            ts.num_sites == 0 or trees_start <= np.min(sites.position)
        ):
            tree_status[0] = OMIT
        if trees_end <= x_max < ts.sequence_length and (
            ts.num_sites == 0 or trees_end >= np.max(sites.position)
        ):
            tree_status[-1] = OMIT

        # Which breakpoints are new ones, as a result of clipping
        new_breaks = np.logical_not(np.isin(ts.breakpoints(as_array=True), old_breaks))
        # A tree whose left (resp. right) breakpoint is new was clipped on that side
        tree_status[new_breaks[:-1]] |= LEFT_CLIP
        tree_status[new_breaks[1:]] |= RIGHT_CLIP
    else:
        tree_status = np.zeros(ts.num_trees, dtype=np.uint8)

    # Range of trees that would be shown, ignoring omitted empty flanking trees
    first_tree = 1 if tree_status[0] & OMIT else 0
    last_tree = ts.num_trees - 2 if tree_status[-1] & OMIT else ts.num_trees - 1
    num_shown_trees = last_tree - first_tree + 1
    if num_shown_trees > max_num_trees:
        # Keep ceil(max/2) trees at the start and floor(max/2) at the end; the
        # middle trees are flagged (overwriting any earlier flags) for omission
        num_start_trees = max_num_trees // 2 + (1 if max_num_trees % 2 else 0)
        num_end_trees = max_num_trees // 2
        assert num_start_trees + num_end_trees == max_num_trees
        tree_status[
            (first_tree + num_start_trees) : (last_tree - num_end_trees + 1)
        ] = (OMIT | OMIT_MIDDLE)

    return ts, tree_status, offsets
|
|
537
|
+
|
|
538
|
+
|
|
539
|
+
def check_y_ticks(ticks: list | Mapping | None) -> Mapping:
|
|
540
|
+
"""
|
|
541
|
+
Later we might want to implement a tick locator function, such that e.g. ticks=5
|
|
542
|
+
selects ~5 nicely spaced tick locations (with sensible behaviour for log scales)
|
|
543
|
+
"""
|
|
544
|
+
if ticks is None:
|
|
545
|
+
return {}
|
|
546
|
+
if isinstance(ticks, Mapping):
|
|
547
|
+
return dict(zip(ticks, create_tick_labels(list(ticks.values()))))
|
|
548
|
+
return dict(zip(ticks, create_tick_labels(ticks)))
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
def rnd(x):
    """
    Round ``x`` to six significant figures, so the SVG output does not carry
    needless precision.

    Zero and non-finite values are returned unchanged. A result with no
    fractional part is returned as an ``int``.
    """
    if x == 0 or not math.isfinite(x):
        return x
    significant = 6 - math.ceil(math.log10(abs(x)))
    rounded = round(x, significant)
    return int(rounded) if int(rounded) == rounded else rounded
|
|
563
|
+
|
|
564
|
+
|
|
565
|
+
def bold_integer(number):
    """
    Return ``str(number)`` with each ASCII digit replaced by its Unicode
    mathematical bold counterpart.

    For simple integers, it's easier to use bold unicode characters than to
    try to get the SVG to render a bold font for part of a string. Non-digit
    characters, such as a leading minus sign, are passed through unchanged,
    so negative numbers no longer raise ValueError.
    """
    bold_digits = "𝟎𝟏𝟐𝟑𝟒𝟓𝟔𝟕𝟖𝟗"
    return "".join(
        bold_digits[int(ch)] if ch in "0123456789" else ch for ch in str(number)
    )
|
|
569
|
+
|
|
570
|
+
|
|
571
|
+
def edge_and_sample_nodes(ts, omit_regions=None):
    """
    Return the sorted, unique ids of nodes that are samples, or that are the
    child or parent of an edge used in the drawn part of the tree sequence.
    Nodes not connected to an edge are often found if x_lim is specified.

    :param omit_regions: Optional array of shape (k, 2) holding ordered,
        non-touching genomic intervals that are not drawn. If given, only edges
        that overlap the complementary (drawn) regions contribute node ids.
    """
    if omit_regions is None or len(omit_regions) == 0:
        ids = np.concatenate((ts.edges_child, ts.edges_parent))
    else:
        ids = np.array([], dtype=ts.edges_child.dtype)
        edges = ts.tables.edges
        assert omit_regions.shape[1] == 2
        omit_regions = omit_regions.flatten()
        assert np.all(omit_regions == np.unique(omit_regions))  # Check they're in order
        use_regions = np.concatenate(([0.0], omit_regions, [ts.sequence_length]))
        use_regions = use_regions.reshape(-1, 2)
        for left, right in use_regions:
            # An edge [e.left, e.right) is used in a drawn region [left, right)
            # whenever the two intervals overlap. Requiring full containment
            # would drop edges that span an omitted region or end exactly at
            # ``right`` (e.g. at the sequence end), even though they appear in
            # drawn trees.
            overlaps = np.logical_and(edges.left < right, edges.right > left)
            used_edges = edges[overlaps]
            ids = np.concatenate((ids, used_edges.child, used_edges.parent))
    return np.unique(
        np.concatenate((ids, np.where(ts.nodes_flags & NODE_IS_SAMPLE)[0]))
    )
|
|
592
|
+
|
|
593
|
+
|
|
594
|
+
def _postorder_tracked_node_traversal(tree, root, collapse_tracked, key_dict=None):
    # Postorder traversal that only descends into subtrees if they contain
    # a tracked node. Additionally, if collapse_tracked is not None, it is
    # interpreted as a proportion, so that we do not descend into a subtree if
    # that proportion or greater of the samples in the subtree are tracked.
    # If key_dict is provided, use this to sort the children. This allows
    # us to put e.g. the subtrees containing the most tracked nodes first.
    # Private function, for use only in drawing._postorder_tracked_minlex_traversal()
    #
    # Nodes that are not descended into (leaves, subtrees with no tracked
    # samples, and collapsed subtrees) are yielded immediately, as tips.
    # Passing root=tskit.NULL starts from the virtual root.

    # If we deliberately specify the virtual root, it should also be returned
    is_virtual_root = root == tree.virtual_root
    if root == tskit.NULL:
        root = tree.virtual_root
    # Stack of (node, visited) pairs: visited=True means the children have
    # already been pushed, so the node is ready to be yielded in postorder
    stack = [(root, False)]
    while stack:
        u, visited = stack.pop()
        if visited:
            if u != tree.virtual_root or is_virtual_root:
                yield u
        else:
            if tree.num_children(u) == 0:
                yield u
            elif tree.num_tracked_samples(u) == 0:
                yield u
            elif (
                collapse_tracked is not None
                and tree.num_children(u) != 1
                and tree.num_tracked_samples(u)
                >= collapse_tracked * tree.num_samples(u)
            ):
                # Collapse: enough of this subtree's samples are tracked
                # (unary nodes are never collapsed)
                yield u
            else:
                stack.append((u, True))
                if key_dict is None:
                    # Children are popped in the reverse of tree.children() order
                    stack.extend((c, False) for c in tree.children(u))
                else:
                    # Push in descending key order so the smallest key pops first
                    stack.extend(
                        sorted(
                            ((c, False) for c in tree.children(u)),
                            key=lambda v: key_dict[v[0]],
                            reverse=True,
                        )
                    )
|
|
637
|
+
|
|
638
|
+
|
|
639
|
+
def _postorder_tracked_minlex_traversal(tree, root=None, *, collapse_tracked=None):
    """
    Postorder traversal for drawing purposes that places child nodes with the
    most tracked sample descendants first (then sorts ties by minlex on leaf node ids).
    Additionally, this traversal only descends into subtrees if they contain a tracked
    node, and may not descend into other subtrees, if the ``collapse_tracked``
    parameter is set to a numeric value. More specifically, if the proportion of
    tracked samples in the subtree is greater than or equal to ``collapse_tracked``,
    the subtree is not descended into.

    This function is not itself a generator. When it is called, a first
    (unsorted) traversal runs eagerly to build the sort keys. It then returns
    a generator for the second, sorted traversal.

    :param root: The node to start from. None (the default) starts from the
        virtual root, which is then not normally yielded itself.
    """

    # Sort key per visited node: (-number of tracked samples, smallest tip id
    # within the traversed subtree)
    key_dict = {}
    parent_array = tree.parent_array
    prev = tree.virtual_root
    if root is None:
        root = tskit.NULL
    for u in _postorder_tracked_node_traversal(tree, root, collapse_tracked):
        # In a postorder, a node whose immediately preceding node is not one of
        # its children was not descended into, i.e. it is a tip of the traversal
        is_tip = parent_array[prev] != u
        prev = u
        if is_tip:
            # Sort by number of tracked samples (desc), then by minlex
            key_dict[u] = (-tree.num_tracked_samples(u), u)
        else:
            min_tip_id = min(key_dict[v][1] for v in tree.children(u) if v in key_dict)
            key_dict[u] = (-tree.num_tracked_samples(u), min_tip_id)

    return _postorder_tracked_node_traversal(
        tree, root, collapse_tracked, key_dict=key_dict
    )
|
|
668
|
+
|
|
669
|
+
|
|
670
|
+
def draw_tree(
    tree,
    width=None,
    height=None,
    node_labels=None,
    node_colours=None,
    mutation_labels=None,
    mutation_colours=None,
    format=None,  # noqa A002
    edge_colours=None,
    time_scale=None,
    tree_height_scale=None,
    max_time=None,
    min_time=None,
    max_tree_height=None,
    order=None,
    omit_sites=None,
):
    """
    Render a single tree as an SVG string or as a text (ascii/unicode) drawing.

    See tree.draw() for documentation on these arguments. The deprecated
    ``tree_height_scale`` and ``max_tree_height`` arguments are used, with a
    FutureWarning, only when ``time_scale`` / ``max_time`` are not given.
    Text formats raise ValueError if any SVG-only option is supplied.
    SVG output defaults to a 200 x 200 image.
    """
    # Deprecated in 0.3.6: translate the old argument names
    if time_scale is None and tree_height_scale is not None:
        time_scale = tree_height_scale
        warnings.warn(
            "tree_height_scale is deprecated; use time_scale instead",
            FutureWarning,
            stacklevel=4,
        )
    if max_time is None and max_tree_height is not None:
        max_time = max_tree_height
        warnings.warn(
            "max_tree_height is deprecated; use max_time instead",
            FutureWarning,
            stacklevel=4,
        )

    fmt = check_format(format)

    if fmt != "svg":
        # Text output: reject every option that only makes sense for SVG
        svg_only_options = (
            ("width", width),
            ("height", height),
            ("mutation_labels", mutation_labels),
            ("mutation_colours", mutation_colours),
            ("node_colours", node_colours),
            ("edge_colours", edge_colours),
            ("time_scale", time_scale),
        )
        for option_name, option_value in svg_only_options:
            if option_value is not None:
                raise ValueError(f"Text trees do not support {option_name}")
        text_tree = VerticalTextTree(
            tree,
            node_labels=node_labels,
            max_time=max_time,
            min_time=min_time,
            use_ascii=fmt == "ascii",
            orientation=TOP,
            order=order,
        )
        return str(text_tree)

    def remap_style(original_map, new_key, none_value):
        # Convert {id: colour} into {id: {"style": ...}}. A None colour maps to
        # none_value: the old semantics were to not draw such items, and
        # setting opacity to zero has the same effect.
        if original_map is None:
            return None
        return {
            key: {"style": none_value if value is None else f"{new_key}:{value};"}
            for key, value in original_map.items()
        }

    # Set style rather than fill & stroke directly to override top stylesheet
    node_attrs = remap_style(node_colours, "fill", "fill-opacity:0;")
    edge_attrs = remap_style(edge_colours, "stroke", "stroke-opacity:0;")
    mutation_attrs = remap_style(mutation_colours, "fill", "fill-opacity:0;")

    svg_tree = SvgTree(
        tree,
        (200 if width is None else width, 200 if height is None else height),
        node_labels=node_labels,
        mutation_labels=mutation_labels,
        time_scale=time_scale,
        max_time=max_time,
        min_time=min_time,
        node_attrs=node_attrs,
        edge_attrs=edge_attrs,
        node_label_attrs=None,
        mutation_attrs=mutation_attrs,
        order=order,
        omit_sites=omit_sites,
    )
    return SVGString(svg_tree.drawing.tostring())
|
|
776
|
+
|
|
777
|
+
|
|
778
|
+
def add_class(attrs_dict, classes_str):
    """
    Append ``classes_str`` to the space-separated "class" entry of ``attrs_dict``.

    If ``attrs_dict`` has no "class" entry yet (lookup raises KeyError), a new
    entry holding just ``classes_str`` is created instead. Modifies
    ``attrs_dict`` in place and returns None.
    """
    try:
        existing_classes = attrs_dict["class"]
    except KeyError:
        attrs_dict["class"] = classes_str
    else:
        attrs_dict["class"] = existing_classes + (" " + classes_str)
|
|
784
|
+
|
|
785
|
+
|
|
786
|
+
@dataclass
class Plotbox:
    """
    A drawing area of ``total_size`` = (width, height) with padding on each side.

    Coordinates follow SVG conventions: ``top``/``left`` are the inner edges
    measured from the origin, ``bottom``/``right`` are the inner edges measured
    back from the full size. A ValueError is raised on construction, and
    whenever the padding is replaced, if the padded interior would be less than
    1 unit wide or high.
    """

    total_size: list
    pad_top: float = 0
    pad_left: float = 0
    pad_bottom: float = 0
    pad_right: float = 0

    def set_padding(self, top, left, bottom, right):
        """Replace all four paddings, then re-validate the interior size."""
        self.pad_top, self.pad_left = top, left
        self.pad_bottom, self.pad_right = bottom, right
        self._check()

    # --- outer extent -----------------------------------------------------

    @property
    def max_x(self):
        """Full width of the box, ignoring padding."""
        return self.total_size[0]

    @property
    def max_y(self):
        """Full height of the box, ignoring padding."""
        return self.total_size[1]

    # --- inner edges ------------------------------------------------------

    @property
    def top(self):  # Alias for consistency with top & bottom
        return self.pad_top

    @property
    def left(self):  # Alias for consistency with top & bottom
        return self.pad_left

    @property
    def bottom(self):
        """Inner bottom edge: full height minus the bottom padding."""
        return self.max_y - self.pad_bottom

    @property
    def right(self):
        """Inner right edge: full width minus the right padding."""
        return self.max_x - self.pad_right

    # --- inner extent -----------------------------------------------------

    @property
    def width(self):
        """Width of the padded interior."""
        return self.right - self.left

    @property
    def height(self):
        """Height of the padded interior."""
        return self.bottom - self.top

    def __post_init__(self):
        # Validate the box as soon as the dataclass has been populated
        self._check()

    def _check(self):
        if self.width < 1 or self.height < 1:
            raise ValueError("Image size too small to fit")

    def draw(self, dwg, add_to, colour="grey"):
        """
        Debugging aid: add two transparent dashed outlines to ``add_to``.

        The first outlines the full box (long dashes, class "outer_plotbox");
        the second outlines the padded interior (short dashes, class
        "inner_plotbox"). ``dwg`` supplies the ``rect`` factory.
        """
        outlines = (
            ((0, 0), (self.max_x, self.max_y), "15,15", "outer_plotbox"),
            ((self.left, self.top), (self.width, self.height), "5,5", "inner_plotbox"),
        )
        for corner, extent, dashes, css_class in outlines:
            add_to.add(
                dwg.rect(
                    corner,
                    extent,
                    fill="white",
                    fill_opacity=0,
                    stroke=colour,
                    stroke_dasharray=dashes,
                    class_=css_class,
                )
            )
|
|
864
|
+
|
|
865
|
+
|
|
866
|
+
class SvgPlot:
    """
    Base class for drawing a single box onto an SVG canvas.

    Holds an ``svgwrite`` drawing (``self.drawing``), a :class:`Plotbox`
    describing the drawable area, and ``root_groups``: named top-level groups
    (plotbox, axes, ...) that callers can retrieve or transplant elsewhere.
    """

    text_height = 14  # May want to calculate this based on a font size
    line_height = text_height * 1.2  # allowing padding above and below a line
    default_width = 200  # for a single tree
    default_height = 200

    def __init__(
        self,
        size,
        svg_class,
        root_svg_attributes=None,
        canvas_size=None,
        preamble=None,
    ):
        """
        Build ``self.drawing`` with one base group of class ``svg_class``.

        :param size: (width, height) of the plot; also the canvas size unless
            ``canvas_size`` is given.
        :param svg_class: CSS class given to the base group.
        :param root_svg_attributes: Extra keyword arguments for the ``Drawing``
            constructor (defaults to none).
        :param canvas_size: Overall SVG canvas size, if different from ``size``.
        :param preamble: Optional element that :meth:`draw` substitutes as the
            first child of the drawing root.
        """
        svg_attrs = {} if root_svg_attributes is None else root_svg_attributes
        drawing = Drawing(
            size=size if canvas_size is None else canvas_size, **svg_attrs
        )

        self.preamble = preamble
        self.image_size = size
        self.plotbox = Plotbox(size)
        self.root_groups = {}
        self.svg_class = svg_class
        self.timescaling = None
        self.root_svg_attributes = svg_attrs
        self.dwg_base = drawing.add(drawing.g(class_=svg_class))
        self.drawing = drawing

    def draw(self, path=None):
        """
        Serialize the drawing and return it as an :class:`SVGString`.

        If ``path`` is given, the drawing is also saved to that file.
        """
        if self.preamble is not None:
            self.drawing.root.children[0] = self.preamble
        svg_text = self.drawing.tostring()
        if path is not None:
            # TODO remove the 'pretty' when we are done debugging this.
            self.drawing.saveas(path, pretty=True)
        return SVGString(svg_text)

    def get_plotbox(self):
        """
        Return the "plotbox" root group, creating it on the first call.
        """
        groups = self.root_groups
        if "plotbox" not in groups:
            groups["plotbox"] = self.dwg_base.add(self.drawing.g(class_="plotbox"))
        return groups["plotbox"]

    def add_text_in_group(self, text, add_to, pos, group_class=None, **kwargs):
        """
        Add ``text`` to ``add_to`` wrapped in a group translated to ``pos``.

        The text is positioned by the group's transform rather than by x/y
        attributes on the text itself, so that any rotation in ``kwargs``
        (e.g. ``transform``) rotates about the text's own origin instead of the
        (0,0) point of the containing group. ``kwargs`` go to the text element.
        """
        x, y = pos[0], pos[1]
        grp_attrs = {"transform": f"translate({rnd(x)} {rnd(y)})"}
        if group_class is not None:
            grp_attrs["class_"] = group_class
        wrapper = add_to.add(self.drawing.g(**grp_attrs))
        wrapper.add(self.drawing.text(text, **kwargs))
|
|
936
|
+
|
|
937
|
+
|
|
938
|
+
class SvgSkippedPlot(SvgPlot):
    """
    Placeholder subplot shown in place of a run of trees that were not drawn.

    It contains two centred lines of text in its plotbox: "<num_skipped> trees"
    above the centre and "skipped" below it.
    """

    def __init__(
        self,
        size,
        num_skipped,
    ):
        super().__init__(size, svg_class="skipped")
        plotbox_group = self.get_plotbox()
        centre_x = self.plotbox.width / 2
        centre_y = self.plotbox.height / 2
        half_line = self.line_height / 2
        caption = (
            (f"{num_skipped} trees", centre_y - half_line),
            ("skipped", centre_y + half_line),
        )
        for line_text, line_y in caption:
            self.add_text_in_group(
                line_text, plotbox_group, (centre_x, line_y), text_anchor="middle"
            )
|
|
960
|
+
|
|
961
|
+
|
|
962
|
+
class SvgAxisPlot(SvgPlot):
    """
    The class used for plotting either a tree or a tree sequence as an SVG file.

    On top of :class:`SvgPlot` this adds a single shared CSS stylesheet,
    optional x (genome position) and y (time) axes with titles, background
    shading joining each plotted tree box to its span on the x axis, and
    site / mutation markers along the x axis. Subclasses must implement
    :meth:`x_transform` to map genome positions to plot coordinates.
    """

    standard_style = (
        ".background path {fill: #808080; fill-opacity: 0}"
        ".background path:nth-child(odd) {fill-opacity: .1}"
        ".x-regions rect {fill: yellow; stroke: black; opacity: 0.5}"  # opaque 4 overlap
        ".axes {font-size: 14px}"
        ".x-axis .tick .lab {font-weight: bold; dominant-baseline: hanging}"
        ".axes, .tree {font-size: 14px; text-anchor: middle}"
        ".axes line, .edge {stroke: black; fill: none}"
        ".axes .ax-skip {stroke-dasharray: 4}"
        ".y-axis .grid {stroke: #FAFAFA}"
        ".node > .sym {fill: black; stroke: none}"
        ".site > .sym {stroke: black}"
        ".mut text {fill: red; font-style: italic}"
        ".mut.extra text {fill: hotpink}"
        ".mut line {fill: none; stroke: none}"  # Default hide mut line to expose edges
        ".mut .sym {fill: none; stroke: red}"
        ".mut.extra .sym {stroke: hotpink}"
        ".node .mut .sym {stroke-width: 1.5px}"
        ".tree text, .tree-sequence text {dominant-baseline: central}"
        ".plotbox .lab.lft {text-anchor: end}"
        ".plotbox .lab.rgt {text-anchor: start}"
        ".polytomy line {stroke: black; stroke-dasharray: 1px, 1px}"
        ".polytomy text {paint-order:stroke;stroke-width:0.3em;stroke:white}"
    )

    # TODO: we may want to make some of the constants below into parameters
    root_branch_fraction = 1 / 8  # Rel root branch len, unless it has a timed mutation
    default_tick_length = 5
    default_tick_length_site = 10
    # Placement of the axes lines within the padding - not used unless axis is plotted
    default_x_axis_offset = 20
    default_y_axis_offset = 40

    def __init__(
        self,
        ts,
        size,
        root_svg_attributes,
        style,
        svg_class,
        time_scale,
        x_axis=None,
        y_axis=None,
        x_label=None,
        y_label=None,
        offsets=None,
        debug_box=None,
        omit_sites=None,
        canvas_size=None,
        mutation_titles=None,
        preamble=None,
    ):
        """
        :param ts: The tree sequence being drawn (kept as ``self.ts``).
        :param style: Extra CSS appended to :attr:`standard_style`, or None.
        :param time_scale: Validated with ``check_time_scale``.
        :param x_axis: Truthy to draw the x axis; when truthy and ``x_label`` is
            None, the label defaults to "Genome position".
        :param y_axis: Validated with ``check_y_axis``; when truthy and
            ``y_label`` is None, a time label is generated, with
            ``ts.time_units`` appended unless they are unknown.
        :param offsets: Id offsets used in CSS class names; defaults to
            ``Offsets()``.
        :param mutation_titles: Optional dict of mutation id -> title text
            attached to that mutation's x-axis symbol.
        The remaining parameters are forwarded to :class:`SvgPlot` or stored.
        """
        super().__init__(
            size,
            svg_class,
            root_svg_attributes,
            canvas_size,
            preamble=preamble,
        )
        self.ts = ts
        dwg = self.drawing
        # Put all styles in a single stylesheet (required for Inkscape 0.92)
        style = self.standard_style + ("" if style is None else style)
        dwg.defs.add(dwg.style(style))
        self.debug_box = debug_box
        self.time_scale = check_time_scale(time_scale)
        self.y_axis = check_y_axis(y_axis)
        self.x_axis = x_axis
        if x_label is None and x_axis:
            x_label = "Genome position"
        if y_label is None and y_axis:
            if time_scale == "rank":
                y_label = "Node time"
            else:
                y_label = "Time ago"
            if ts.time_units != tskit.TIME_UNITS_UNKNOWN:
                y_label += f" ({ts.time_units})"
        self.x_label = x_label
        self.y_label = y_label
        self.offsets = Offsets() if offsets is None else offsets
        self.omit_sites = omit_sites
        self.mutation_titles = {} if mutation_titles is None else mutation_titles
        self.mutations_outside_tree = set()  # mutations in here get an additional class

    def set_spacing(self, top=0, left=0, bottom=0, right=0):
        """
        Set the plotbox padding, reserving room for any axes and axis titles.

        The x-axis offset is added to the bottom padding when the x axis is
        drawn. When a y axis is drawn, the padding on its side is replaced
        (not increased) by the y-axis offset. Each offset grows by one
        ``line_height`` when the corresponding axis title is set. If
        ``debug_box`` is set, the plotbox outlines are drawn into a "debug"
        root group.
        """
        self.x_axis_offset = self.default_x_axis_offset
        self.y_axis_offset = self.default_y_axis_offset
        if self.x_label:
            self.x_axis_offset += self.line_height
        if self.y_label:
            self.y_axis_offset += self.line_height
        if self.x_axis:
            bottom += self.x_axis_offset
        if self.y_axis == "left":
            # Override user-provided values, so y-axis is at x=0
            left = self.y_axis_offset
        if self.y_axis == "right":
            right = self.y_axis_offset
        self.plotbox.set_padding(top, left, bottom, right)
        if self.debug_box:
            self.root_groups["debug"] = self.dwg_base.add(
                self.drawing.g(class_="debug")
            )
            self.plotbox.draw(self.drawing, self.root_groups["debug"])

    def get_axes(self):
        """Return the "axes" root group, creating it on first use."""
        if "axes" not in self.root_groups:
            self.root_groups["axes"] = self.dwg_base.add(self.drawing.g(class_="axes"))
        return self.root_groups["axes"]

    def draw_x_axis(
        self,
        tick_positions=None,  # np.array of ax ticks below (+ above if sites is None)
        tick_labels=None,  # Tick labels below axis. If None, use the position value
        tick_length_lower=default_tick_length,
        tick_length_upper=None,  # If None, use the same as tick_length_lower
        site_muts=None,  # A dict of site id => mutation to plot as ticks on the x axis
        alternate_dash_positions=None,  # Where to alternate the axis from solid to dash
        x_regions=None,  # A dict of (left, right):label items to place in boxes
    ):
        """
        Draw the x (genome position) axis into an "x-axis" group.

        Does nothing unless ``self.x_axis`` is truthy. Otherwise draws, in
        order: the axis title, any labelled ``x_regions`` boxes, the axis line
        (split at ``alternate_dash_positions`` into alternating solid/dashed
        segments), tick marks with labels, and (unless ``omit_sites``) one
        vertical mark per site in ``site_muts`` topped by a chevron for each
        of its mutations.

        :raises ValueError: if an ``x_regions`` key is not a valid
            (left, right) interval within ``[0, ts.sequence_length]``.
        """
        if not self.x_axis:
            return
        if alternate_dash_positions is None:
            alternate_dash_positions = np.array([])
        if x_regions is None:
            x_regions = {}
        dwg = self.drawing
        axes = self.get_axes()
        x_axis = axes.add(dwg.g(class_="x-axis"))
        if self.x_label:
            self.add_text_in_group(
                self.x_label,
                x_axis,
                pos=((self.plotbox.left + self.plotbox.right) / 2, self.plotbox.max_y),
                group_class="title",
                class_="lab",
                transform="translate(0 -11)",
                text_anchor="middle",
            )
        if len(x_regions) > 0:
            regions_group = x_axis.add(dwg.g(class_="x-regions"))
            for i, ((left, right), label) in enumerate(x_regions.items()):
                if not (0 <= left < right <= self.ts.sequence_length):
                    raise ValueError(
                        f"Invalid coordinates ({left} to {right}) for x-axis region"
                    )
                x1 = self.x_transform(left)
                x2 = self.x_transform(right)
                y = self.plotbox.max_y - self.x_axis_offset
                region = regions_group.add(dwg.g(class_=f"r{i}"))
                # Use an f-string: previously the rect got the literal class "r{i}"
                region.add(
                    dwg.rect((x1, y), (x2 - x1, self.line_height), class_=f"r{i}")
                )
                self.add_text_in_group(
                    label,
                    region,
                    pos=((x2 + x1) / 2, y + self.line_height / 2),
                    class_="lab",
                    text_anchor="middle",
                )
        if tick_length_upper is None:
            tick_length_upper = tick_length_lower
        y = rnd(self.plotbox.max_y - self.x_axis_offset)
        # Segment boundaries: plot left edge, each alternation point, right edge
        dash_locs = np.concatenate(
            (
                [self.plotbox.left],
                self.x_transform(alternate_dash_positions),
                [self.plotbox.right],
            )
        )
        for i, (x1, x2) in enumerate(zip(dash_locs[:-1], dash_locs[1:])):
            x_axis.add(
                dwg.line(
                    (rnd(x1), y),
                    (rnd(x2), y),
                    class_="ax-skip" if i % 2 else "ax-line",
                )
            )
        if tick_positions is not None:
            if tick_labels is None or isinstance(tick_labels, np.ndarray):
                if tick_labels is None:
                    tick_labels = tick_positions
                tick_labels = create_tick_labels(tick_labels)  # format integers

            # Ticks only extend above the axis when no site marks are drawn there
            upper_length = -tick_length_upper if site_muts is None else 0
            ticks_group = x_axis.add(dwg.g(class_="ticks"))
            for pos, lab in itertools.zip_longest(tick_positions, tick_labels):
                tick = ticks_group.add(
                    dwg.g(
                        class_="tick",
                        transform=f"translate({rnd(self.x_transform(pos))} {y})",
                    )
                )
                tick.add(dwg.line((0, rnd(upper_length)), (0, rnd(tick_length_lower))))
                self.add_text_in_group(
                    lab,
                    tick,
                    class_="lab",
                    # place origin at the bottom of the tick plus a single px space
                    pos=(0, tick_length_lower + 1),
                )
        if not self.omit_sites and site_muts is not None:
            # Add sites as vertical lines with overlaid mutations as upper chevrons
            for s_id, mutations in site_muts.items():
                s = self.ts.site(s_id)
                x = self.x_transform(s.position)
                site = x_axis.add(
                    dwg.g(
                        class_=f"site s{s.id + self.offsets.site}",
                        transform=f"translate({rnd(x)} {y})",
                    )
                )
                site.add(dwg.line((0, 0), (0, rnd(-tick_length_upper)), class_="sym"))
                for i, m in enumerate(reversed(mutations)):
                    mutation_class = f"mut m{m.id + self.offsets.mutation}"
                    if m.id in self.mutations_outside_tree:
                        mutation_class += " extra"
                    mut = dwg.g(class_=mutation_class)
                    # Stack chevrons upwards, 4px apart, in reversed list order
                    h = -i * 4 - 1.5
                    w = tick_length_upper / 4
                    # Chevron symbol
                    symbol = mut.add(
                        dwg.polyline(
                            [
                                (rnd(w), rnd(h - 2 * w)),
                                (0, rnd(h)),
                                (rnd(-w), rnd(h - 2 * w)),
                            ],
                            class_="sym",
                        )
                    )
                    if m.id in self.mutation_titles:
                        symbol.set_desc(title=self.mutation_titles[m.id])
                    site.add(mut)

    def draw_y_axis(
        self,
        ticks,  # A dict of pos->label
        upper=None,  # In plot coords
        lower=None,  # In plot coords
        tick_length_outer=default_tick_length,  # Positive means towards the outside
        gridlines=None,
        side="left",  # 'left' or 'right', where the axis is drawn
    ):
        """
        Draw the y (time) axis into a "y-axis" group on ``side``.

        Does nothing if neither ``self.y_axis`` nor ``self.y_label`` is set.
        The title is drawn whenever ``y_label`` is set; the axis line, ticks,
        labels and optional gridlines only when ``y_axis`` is set. Tick keys
        are mapped to plot coordinates with ``self.timescaling.transform``,
        and a warning is logged for any that fall outside [upper, lower].
        """
        if not self.y_axis and not self.y_label:
            return
        if upper is None:
            upper = self.plotbox.top
        if lower is None:
            lower = self.plotbox.bottom
        dwg = self.drawing
        if side == "left":
            x = rnd(self.y_axis_offset)
            width = self.plotbox.right - x
            direction = -1
            text_anchor = "end"
            pos = (0, (upper + lower) / 2)
            transform = "translate(11) rotate(-90)"
        else:
            x = rnd(self.plotbox.max_x - self.y_axis_offset)
            width = x - self.plotbox.left
            direction = 1
            text_anchor = "start"
            pos = (self.plotbox.max_x, (upper + lower) / 2)
            transform = "translate(-11) rotate(90)"
        axes = self.get_axes()
        y_axis = axes.add(dwg.g(class_="y-axis"))
        if self.y_label:
            self.add_text_in_group(
                self.y_label,
                y_axis,
                pos=pos,
                group_class="title",
                class_="lab",
                text_anchor="middle",
                transform=transform,
            )
        if self.y_axis:
            y_axis.add(dwg.line((x, rnd(lower)), (x, rnd(upper)), class_="ax-line"))
            ticks_group = y_axis.add(dwg.g(class_="ticks"))
            tick_outside_axis = {}
            for y, label in ticks.items():
                y_pos = self.timescaling.transform(y)
                if y_pos > lower or y_pos < upper:  # nb lower > upper in SVG coords
                    tick_outside_axis[y] = label
                tick = ticks_group.add(
                    dwg.g(class_="tick", transform=f"translate({x} {rnd(y_pos)})")
                )
                if gridlines:
                    tick.add(dwg.line((0, 0), (rnd(width), 0), class_="grid"))
                tick.add(dwg.line((0, 0), (rnd(direction * tick_length_outer), 0)))
                self.add_text_in_group(
                    # place the origin at the left of the tickmark plus a single px space
                    label,
                    tick,
                    pos=(rnd(direction * (tick_length_outer + 1)), 0),
                    class_="lab",
                    text_anchor=text_anchor,
                )
            if len(tick_outside_axis) > 0:
                logging.warning(
                    f"Ticks {tick_outside_axis} lie outside the plotted axis"
                )

    def shade_background(
        self,
        breaks,
        tick_length_lower,
        tree_width=None,
        bottom_padding=None,
    ):
        """
        Add a "background" root group of shaded paths linking trees to the axis.

        For each pair of consecutive ``breaks``, one closed path is drawn. It
        covers the box of the corresponding tree (trees are ``tree_width``
        apart, starting at the plotbox left edge), then runs diagonally down to
        that genomic span on the x axis. The stylesheet shades odd-numbered
        paths, giving alternating bands. Does nothing if the x axis is not
        drawn.
        """
        if not self.x_axis:
            return
        if tree_width is None:
            tree_width = self.plotbox.width
        if bottom_padding is None:
            bottom_padding = self.plotbox.pad_bottom
        plot_breaks = self.x_transform(np.array(breaks))
        dwg = self.drawing

        # For tree sequences, we need to add on the background shaded regions
        self.root_groups["background"] = self.dwg_base.add(dwg.g(class_="background"))
        y = self.image_size[1] - self.x_axis_offset - self.plotbox.top
        for i in range(1, len(breaks)):
            break_x = plot_breaks[i]
            prev_break_x = plot_breaks[i - 1]
            tree_x = i * tree_width + self.plotbox.left
            prev_tree_x = (i - 1) * tree_width + self.plotbox.left
            # Shift diagonal lines between tree & axis into the treebox a little
            diag_height = y - (self.image_size[1] - bottom_padding) + self.plotbox.top
            self.root_groups["background"].add(
                # NB: the path below draws straight diagonal lines between the tree boxes
                # and the X axis. An alternative implementation using bezier curves could
                # substitute the following for lines 2 and 4 of the path spec string
                # "l0,{box_h:g} c0,{diag_h} {rdiag_x},0 {rdiag_x},{diag_h} "
                # "c0,-{diag_h} {ldiag_x},0 {ldiag_x},-{diag_h} l0,-{box_h:g}z"
                dwg.path(
                    "M{start_x:g},{top:g} l{box_w:g},0 "  # Top left to top right of tree
                    "l0,{box_h:g} l{rdiag_x:g},{diag_h:g} "  # Down to axis
                    "l0,{tick_h:g} l{ax_x:g},0 l0,-{tick_h:g} "  # Between axis ticks
                    "l{ldiag_x:g},-{diag_h:g} l0,-{box_h:g}z".format(  # Up from axis
                        top=rnd(self.plotbox.top),
                        start_x=rnd(prev_tree_x),
                        box_w=rnd(tree_x - prev_tree_x),
                        box_h=rnd(y - diag_height),
                        rdiag_x=rnd(break_x - tree_x),
                        diag_h=rnd(diag_height),
                        tick_h=rnd(tick_length_lower),
                        ax_x=rnd(prev_break_x - break_x),
                        ldiag_x=rnd(rnd(prev_tree_x) - rnd(prev_break_x)),
                    )
                )
            )

    def x_transform(self, x):
        """
        Map genome position(s) ``x`` to plot x coordinate(s).

        Must be overridden by subclasses; callers in this class pass both
        scalars and numpy arrays.
        """
        raise NotImplementedError(
            "No transform func defined for genome pos -> plot coords"
        )
|
|
1329
|
+
|
|
1330
|
+
|
|
1331
|
+
class SvgTreeSequence(SvgAxisPlot):
|
|
1332
|
+
"""
|
|
1333
|
+
A class to draw a tree sequence in SVG format.
|
|
1334
|
+
|
|
1335
|
+
See :meth:`TreeSequence.draw_svg` for a description of usage and parameters.
|
|
1336
|
+
"""
|
|
1337
|
+
|
|
1338
|
+
def __init__(
    self,
    ts,
    size,
    x_scale,
    time_scale,
    node_labels,
    mutation_labels,
    root_svg_attributes,
    style,
    order,
    force_root_branch,
    symbol_size,
    x_axis,
    y_axis,
    x_label,
    y_label,
    y_ticks,
    x_regions=None,
    y_gridlines=None,
    x_lim=None,
    max_time=None,
    min_time=None,
    node_attrs=None,
    mutation_attrs=None,
    edge_attrs=None,
    node_label_attrs=None,
    mutation_label_attrs=None,
    node_titles=None,
    mutation_titles=None,
    tree_height_scale=None,
    max_tree_height=None,
    max_num_trees=None,
    title=None,
    preamble=None,
    **kwargs,
):
    """
    Lay out and draw a whole tree sequence.

    Steps:

    1. Clip ``ts`` to ``x_lim`` (via ``clip_ts``, which may also omit trees
       beyond ``max_num_trees``).
    2. Build one :class:`SvgTree` subplot per plotted tree, plus one
       :class:`SvgSkippedPlot` placeholder per run of omitted middle trees.
    3. Draw the optional title, the x axis, and (when ``y_axis`` is not None)
       the y axis.
    4. Place the subplots side by side inside this plot's plotbox.

    See :meth:`TreeSequence.draw_svg` for parameter descriptions. The
    deprecated ``max_tree_height`` / ``tree_height_scale`` are used only when
    ``max_time`` / ``time_scale`` are None, and then emit a FutureWarning.

    :raises ValueError: if a y axis is requested but the plotted trees do
        not share a single time scaling.
    """
    if max_time is None and max_tree_height is not None:
        max_time = max_tree_height
        # Deprecated in 0.3.6
        warnings.warn(
            "max_tree_height is deprecated; use max_time instead",
            FutureWarning,
            stacklevel=4,
        )
    if time_scale is None and tree_height_scale is not None:
        time_scale = tree_height_scale
        # Deprecated in 0.3.6
        warnings.warn(
            "tree_height_scale is deprecated; use time_scale instead",
            FutureWarning,
            stacklevel=4,
        )
    x_lim = check_x_lim(x_lim, max_x=ts.sequence_length)
    # clip_ts returns the clipped ts, a per-tree status bitmask and id offsets
    ts, self.tree_status, offsets = clip_ts(ts, x_lim[0], x_lim[1], max_num_trees)

    # NB: in Python `&` binds tighter than `==`, so this is (status & OMIT) == 0.
    # NOTE(review): find_used_trees instead tests (status & OMIT) != OMIT; these
    # agree only if OMIT is a single bit (or never partially set) - confirm.
    use_tree = self.tree_status & OMIT == 0
    # np.diff on a boolean array gives True where neighbours differ, so a tree is
    # flagged when its OMIT_MIDDLE status differs from the next tree's. For an
    # omitted tree this marks the last tree of an omitted run, where a single
    # "skipped" placeholder is drawn below.
    use_skipped = np.append(np.diff(self.tree_status & OMIT_MIDDLE == 0) == 1, 0)
    num_plotboxes = np.sum(np.logical_or(use_tree, use_skipped))
    if size is None:
        size = (self.default_width * int(num_plotboxes), self.default_height)
    if max_time is None:
        max_time = "ts"
    if min_time is None:
        min_time = "ts"
    # X axis shown by default
    if x_axis is None:
        x_axis = True
    super().__init__(
        ts,
        size,
        root_svg_attributes,
        style,
        svg_class="tree-sequence",
        time_scale=time_scale,
        x_axis=x_axis,
        y_axis=y_axis,
        x_label=x_label,
        y_label=y_label,
        offsets=offsets,
        mutation_titles=mutation_titles,
        preamble=preamble,
        **kwargs,
    )
    x_scale = check_x_scale(x_scale)
    order = check_order(order)
    if node_labels is None:
        node_labels = {u: str(u) for u in range(ts.num_nodes)}
    if force_root_branch is None:
        # Force root branches (in every tree, for consistency) if any plotted
        # tree has a mutation above one of its roots
        force_root_branch = any(
            any(tree.parent(mut.node) == NULL for mut in tree.mutations())
            for tree, use in zip(ts.trees(), use_tree)
            if use
        )

    # TODO add general padding arguments following matplotlib's terminology.
    self.set_spacing(
        top=0 if title is None else self.line_height, left=20, bottom=10, right=20
    )
    # All subplots (trees and placeholders) share the plotbox width equally
    subplot_size = (self.plotbox.width / num_plotboxes, self.plotbox.height)
    subplots = []
    for tree, use, summary in zip(ts.trees(), use_tree, use_skipped):
        if use:
            subplots.append(
                SvgTree(
                    tree,
                    size=subplot_size,
                    time_scale=time_scale,
                    node_labels=node_labels,
                    mutation_labels=mutation_labels,
                    node_titles=node_titles,
                    mutation_titles=mutation_titles,
                    order=order,
                    force_root_branch=force_root_branch,
                    symbol_size=symbol_size,
                    max_time=max_time,
                    min_time=min_time,
                    node_attrs=node_attrs,
                    mutation_attrs=mutation_attrs,
                    edge_attrs=edge_attrs,
                    node_label_attrs=node_label_attrs,
                    mutation_label_attrs=mutation_label_attrs,
                    offsets=offsets,
                    # Do not plot axes on these subplots
                    **kwargs,  # pass though e.g. debug boxes
                )
            )
            last_used_index = tree.index
        elif summary:
            # NOTE(review): assumes a plotted tree precedes every omitted run,
            # otherwise last_used_index is unbound here - presumably guaranteed
            # by clip_ts; confirm.
            subplots.append(
                SvgSkippedPlot(
                    size=subplot_size, num_skipped=tree.index - last_used_index
                )
            )
    y = self.plotbox.top
    if title is not None:
        # The title is added to the drawing itself, at the top centre
        self.add_text_in_group(
            title,
            self.drawing,
            pos=(self.plotbox.max_x / 2, 0),
            dominant_baseline="hanging",
            group_class="title",
            text_anchor="middle",
        )
    # The first subplot's plotbox is used as the reference for axis geometry
    self.tree_plotbox = subplots[0].plotbox
    tree_is_used, breaks, skipbreaks = self.find_used_trees()
    self.draw_x_axis(
        x_scale,
        tree_is_used,
        breaks,
        skipbreaks,
        tick_length_lower=self.default_tick_length,  # TODO - parameterize
        tick_length_upper=self.default_tick_length_site,  # TODO - parameterize
        x_regions=x_regions,
    )
    y_low = self.tree_plotbox.bottom
    if y_axis is not None:
        # A single y axis only makes sense if every subplot scales time the same
        # way. NOTE(review): if no subplot has a timescaling, tscales is empty
        # and pop() raises KeyError - confirm that cannot happen.
        tscales = {s.timescaling for s in subplots if s.timescaling}
        if len(tscales) > 1:
            raise ValueError(
                "Can't draw a tree sequence Y axis if trees vary in timescale"
            )
        self.timescaling = tscales.pop()
        y_low = self.timescaling.transform(self.timescaling.min_time)
        if y_ticks is None:
            # Default ticks: the distinct times of nodes used in the plotted region
            used_nodes = edge_and_sample_nodes(ts, breaks[skipbreaks])
            y_ticks = np.unique(ts.nodes_time[used_nodes])
            if self.time_scale == "rank":
                # Ticks labelled by time not rank
                y_ticks = dict(enumerate(y_ticks))

    self.draw_y_axis(
        ticks=check_y_ticks(y_ticks),
        upper=self.tree_plotbox.top,
        lower=y_low,
        tick_length_outer=self.default_tick_length,
        gridlines=y_gridlines,
        side="right" if y_axis == "right" else "left",
    )

    subplot_x = self.plotbox.left
    container = self.get_plotbox()  # Top-level TS plotbox contains all trees
    container["class"] = container["class"] + " trees"
    for subplot in subplots:
        # Move each subplot's root groups into a group translated to its slot
        svg_subplot = container.add(
            self.drawing.g(
                class_=subplot.svg_class,
                transform=f"translate({rnd(subplot_x)} {y})",
            )
        )
        for svg_items in subplot.root_groups.values():
            svg_subplot.add(svg_items)
        subplot_x += subplot.image_size[0]
|
|
1531
|
+
|
|
1532
|
+
def find_used_trees(self):
    """
    Work out which trees and breakpoints take part in the plot.

    Returns a 3-tuple:

    - a boolean array, one entry per tree, True if the tree is drawn;
    - the breakpoints that bound at least one drawn tree;
    - an (n, 2) integer array (usually n == 0) of index pairs into those
      breakpoints, each pair delimiting a region shown as "skipped".
    """
    status = self.tree_status
    drawn = (status & OMIT) != OMIT
    # Breakpoint k is the left edge of tree k and the right edge of tree k - 1;
    # keep it if either neighbouring tree is drawn.
    is_left_of_drawn = np.append(drawn, False)
    is_right_of_drawn = np.insert(drawn, 0, False)
    keep_break = is_left_of_drawn | is_right_of_drawn
    kept_breaks = self.ts.breakpoints(True)[keep_break]
    # Interior breakpoints where the "omitted from the middle" flag flips
    flips = np.diff(status & OMIT_MIDDLE) != 0
    is_skip_edge = np.concatenate(([False], flips, [False]))
    skip_edges = np.flatnonzero(is_skip_edge[keep_break])
    assert len(skip_edges) % 2 == 0  # all skipped regions have start, end
    return drawn, kept_breaks, skip_edges.reshape((-1, 2))
|
|
1551
|
+
|
|
1552
|
+
def draw_x_axis(
    self,
    x_scale,
    tree_is_used,
    breaks,
    skipbreaks,
    x_regions,
    tick_length_lower=SvgAxisPlot.default_tick_length,
    tick_length_upper=SvgAxisPlot.default_tick_length_site,
):
    """
    Draw the x axis of a tree sequence plot.

    This wraps the generic ``SvgAxisPlot.draw_x_axis``, additionally handling
    the background shading shown behind each tree and the compression of any
    runs of trees omitted from the middle of the tree sequence.

    :param x_scale: "physical" to place ticks and sites at (transformed) genome
        positions; any other value gives a "treewise" axis with one evenly
        spaced tick per breakpoint and no sites.
    :param tree_is_used: boolean array, True for each tree that is plotted.
    :param breaks: breakpoints used in the plot.
    :param skipbreaks: 2-column array of indexes into ``breaks``, one row per
        skipped region (start, end).
    :param x_regions: regions passed on to the base-class axis drawing.
    :param tick_length_lower: length of the ticks below the axis.
    :param tick_length_upper: length of the ticks above the axis (sites).
    :raises ValueError: if non-empty ``x_regions`` are given for a treewise plot.
    """
    if not (self.x_axis or self.x_label):
        return

    if x_scale == "physical":
        # Here x_transform is used for the ticks, the background shading and the
        # sites along the axis; each tree has its own transform for node positions.
        # When trees are skipped, the transform is piecewise linear: every skipped
        # region is squeezed into one tree's width, and the remaining plot width
        # is shared in proportion to genome length by the non-skipped stretches.
        skip_width = self.tree_plotbox.width
        skip_bounds = breaks[skipbreaks]  # one (start, end) row per skipped region
        skipregion_span = skip_bounds[:, 1] - skip_bounds[:, 0]
        num_skipped = len(skipregion_span)
        std_scale = (self.plotbox.width - skip_width * num_skipped) / (
            breaks[-1] - breaks[0] - np.sum(skipregion_span)
        )
        skipregion_pos = skip_bounds.flatten()
        genome_pos = np.concatenate(([breaks[0]], skipregion_pos, [breaks[-1]]))
        # Segments alternate: normal, skipped, normal, ... ; skipped ones keep
        # the fixed width, normal ones are scaled by genome length
        plot_step = np.full(len(genome_pos) - 1, skip_width)
        plot_step[::2] = std_scale * np.diff(genome_pos)[::2]
        plot_pos = np.cumsum(np.insert(plot_step, 0, self.plotbox.left))
        # Express every segment as slope + intercept
        slope = np.diff(plot_pos) / np.diff(genome_pos)
        intercept = plot_pos[1:] - slope * genome_pos[1:]

        def piecewise_transform(pos):
            segment = np.searchsorted(skipregion_pos, pos)
            return pos * slope[segment] + intercept[segment]

        self.x_transform = piecewise_transform
        tick_positions = breaks
        site_muts = {
            site.id: site.mutations
            for tree, use in zip(self.ts.trees(), tree_is_used)
            if use
            for site in tree.sites()
        }
        self.shade_background(
            breaks,
            tick_length_lower,
            self.tree_plotbox.max_x,
            self.plotbox.pad_bottom + self.tree_plotbox.pad_bottom,
        )
    else:
        # Treewise: x_transform is only applied to tick positions, which are
        # simply 0..num_breaks-1 spread evenly over the plot width.
        # No background shading is needed in this case.
        self.x_transform = (
            lambda x: self.plotbox.left + x / (len(breaks) - 1) * self.plotbox.width
        )
        tick_positions = np.arange(len(breaks))
        site_muts = None  # It doesn't make sense to plot sites for "treewise" plots
        tick_length_upper = None  # No sites plotted, so use the default upper tick
        if x_regions is not None and len(x_regions) > 0:
            raise ValueError("x_regions are not supported for treewise plots")
        skipregion_pos = skipbreaks.flatten()

    # Drop the outermost tick(s) if any plotted tree is clipped on that side
    used_status = self.tree_status[tree_is_used]
    first_tick = 1 if np.any(used_status & LEFT_CLIP) else 0
    last_tick = -1 if np.any(used_status & RIGHT_CLIP) else None
    shown = slice(first_tick, last_tick)

    super().draw_x_axis(
        tick_positions=tick_positions[shown],
        tick_labels=breaks[shown],
        tick_length_lower=tick_length_lower,
        tick_length_upper=tick_length_upper,
        site_muts=site_muts,
        alternate_dash_positions=skipregion_pos,
        x_regions=x_regions,
    )
|
|
1640
|
+
|
|
1641
|
+
|
|
1642
|
+
class SvgTree(SvgAxisPlot):
|
|
1643
|
+
"""
|
|
1644
|
+
A class to draw a tree in SVG format.
|
|
1645
|
+
|
|
1646
|
+
See :meth:`Tree.draw_svg` for a description of usage and frequently used parameters.
|
|
1647
|
+
"""
|
|
1648
|
+
|
|
1649
|
+
# Summary of several untracked lineages packed into one dashed line:
# number of packed branches, total samples below them, and [start, end] x positions
PolytomyLine = collections.namedtuple(
    "PolytomyLine", ["num_branches", "num_samples", "line_pos"]
)
margin_left = 20
margin_right = 20
# The oldest point sits line_height below this (2 * line_height with a title)
margin_top = 10
# The youngest plot points sit line_height above this bottom margin
margin_bottom = 15
|
|
1656
|
+
|
|
1657
|
+
def __init__(
    self,
    tree,
    size=None,
    max_time=None,
    min_time=None,
    max_tree_height=None,
    node_labels=None,
    mutation_labels=None,
    node_titles=None,
    mutation_titles=None,
    root_svg_attributes=None,
    style=None,
    order=None,
    force_root_branch=None,
    symbol_size=None,
    x_axis=None,
    y_axis=None,
    x_label=None,
    y_label=None,
    title=None,
    x_regions=None,
    y_ticks=None,
    y_gridlines=None,
    all_edge_mutations=None,
    time_scale=None,
    tree_height_scale=None,
    node_attrs=None,
    mutation_attrs=None,
    edge_attrs=None,
    node_label_attrs=None,
    mutation_label_attrs=None,
    offsets=None,
    omit_sites=None,
    pack_untracked_polytomies=None,
    preamble=None,
    **kwargs,
):
    """
    Lay out and draw a single tree as SVG.

    Collects the mutations to plot above each node, builds the SVG attribute
    dictionaries for nodes, edges, mutations and their labels, assigns node
    coordinates, then draws the title, axes and the tree itself. See
    :meth:`Tree.draw_svg` for a description of the parameters.
    ``max_tree_height`` and ``tree_height_scale`` are deprecated aliases of
    ``max_time`` and ``time_scale`` (a FutureWarning is issued if used).
    """
    if max_time is None and max_tree_height is not None:
        max_time = max_tree_height
        # Deprecated in 0.3.6
        warnings.warn(
            "max_tree_height is deprecated; use max_time instead",
            FutureWarning,
            stacklevel=4,
        )
    if time_scale is None and tree_height_scale is not None:
        time_scale = tree_height_scale
        # Deprecated in 0.3.6
        warnings.warn(
            "tree_height_scale is deprecated; use time_scale instead",
            FutureWarning,
            stacklevel=4,
        )
    if size is None:
        size = (self.default_width, self.default_height)
    if symbol_size is None:
        symbol_size = 6
    self.symbol_size = symbol_size
    self.pack_untracked_polytomies = pack_untracked_polytomies
    ts = tree.tree_sequence
    tree_index = tree.index
    if offsets is not None:
        tree_index += offsets.tree
    super().__init__(
        ts,
        size,
        root_svg_attributes,
        style,
        svg_class=f"tree t{tree_index}",
        time_scale=time_scale,
        x_axis=x_axis,
        y_axis=y_axis,
        x_label=x_label,
        y_label=y_label,
        offsets=offsets,
        omit_sites=omit_sites,
        preamble=preamble,
        **kwargs,
    )
    self.tree = tree
    if order is None or isinstance(order, str):
        # Can't use the Tree.postorder array as we need minlex
        self.postorder_nodes = list(tree.nodes(order=check_order(order)))
    else:
        # Currently undocumented feature: we can pass a (postorder) list
        # of nodes to plot, which allows us to draw a subset of nodes, or
        # stop traversing certain subtrees
        self.postorder_nodes = order

    # Create some instance variables for later use in plotting
    self.node_mutations = collections.defaultdict(list)
    self.edge_attrs = {}
    self.node_attrs = {}
    self.node_label_attrs = {}
    self.mutation_attrs = {}
    self.mutation_label_attrs = {}
    self.node_titles = {} if node_titles is None else node_titles
    self.mutation_titles = {} if mutation_titles is None else mutation_titles
    self.mutations_over_roots = False

    # Mutations collected per node; those above nodes not in this tree are
    # reported (with offset ids) rather than plotted
    nodes = set(tree.nodes())
    unplotted = []
    if not omit_sites:
        for site in tree.sites():
            for mutation in site.mutations:
                if mutation.node in nodes:
                    self.node_mutations[mutation.node].append(mutation)
                    if tree.parent(mutation.node) == NULL:
                        self.mutations_over_roots = True
                else:
                    unplotted.append(mutation.id + self.offsets.mutation)
    if len(unplotted) > 0:
        warnings.warn(
            f"Mutations {unplotted} are above nodes which are not present in the "
            "displayed tree, so are not plotted on the topology.",
            UserWarning,
            stacklevel=2,
        )

    self.left_extent = tree.interval.left
    self.right_extent = tree.interval.right
    if not omit_sites and all_edge_mutations:
        # Also show mutations that lie on one of this tree's edges but outside
        # the tree's own interval, widening the plotted genome extent to fit.
        # Use the tree sequence's column arrays directly rather than going
        # through ts.tables, which is documented as returning a copy.
        tree_left = tree.interval.left
        tree_right = tree.interval.right
        edge_left = ts.edges_left
        edge_right = ts.edges_right
        node_edges = tree.edge_array
        # whittle mutations down so we only need look at those above the tree nodes
        mutations_node = ts.mutations_node
        focal_mutations = np.isin(
            mutations_node, np.fromiter(nodes, mutations_node.dtype)
        )
        mutation_nodes = mutations_node[focal_mutations]
        mutation_positions = ts.sites_position[ts.mutations_site][focal_mutations]
        mutation_ids = np.arange(ts.num_mutations, dtype=int)[focal_mutations]
        for m_id, node, pos in zip(mutation_ids, mutation_nodes, mutation_positions):
            curr_edge = node_edges[node]
            if curr_edge < 0:
                continue  # node has no edge above it in this tree
            if edge_left[curr_edge] <= pos < tree_left:
                # Mutation on this edge but to left of plotted tree
                self.node_mutations[node].append(ts.mutation(m_id))
                self.mutations_outside_tree.add(m_id)
                self.left_extent = min(self.left_extent, pos)
            elif tree_right <= pos < edge_right[curr_edge]:
                # Mutation on this edge but to right of plotted tree
                self.node_mutations[node].append(ts.mutation(m_id))
                self.mutations_outside_tree.add(m_id)
                self.right_extent = max(self.right_extent, pos)
        if self.right_extent != tree.interval.right:
            # Use nextafter so extent of plotting incorporates the mutation
            self.right_extent = np.nextafter(self.right_extent, self.right_extent + 1)

    # attributes for symbols (from here on symbol_size is a formatted string)
    half_symbol_size = f"{rnd(symbol_size / 2):g}"
    symbol_size = f"{rnd(symbol_size):g}"
    for u in tree.nodes():
        self.edge_attrs[u] = {}
        if edge_attrs is not None and u in edge_attrs:
            self.edge_attrs[u].update(edge_attrs[u])
        if tree.is_sample(u):
            # a square: set bespoke svgwrite params
            self.node_attrs[u] = {
                "size": (symbol_size,) * 2,
                "insert": ("-" + half_symbol_size,) * 2,
            }
        else:
            # a circle: set bespoke svgwrite param `centre` and default radius
            self.node_attrs[u] = {"center": (0, 0), "r": half_symbol_size}
        if node_attrs is not None and u in node_attrs:
            self.node_attrs[u].update(node_attrs[u])
        add_class(self.node_attrs[u], "sym")  # class 'sym' for symbol
        label = ""
        if node_labels is None:
            label = str(u)
        elif u in node_labels:
            label = str(node_labels[u])
        self.node_label_attrs[u] = {"text": label}
        # NOTE(review): unlike the other attribute dicts, the "lab" class is
        # added *before* the user's attributes are merged in here, so a
        # user-supplied "class" may replace it -- confirm this is intended
        add_class(self.node_label_attrs[u], "lab")  # class 'lab' for label
        if node_label_attrs is not None and u in node_label_attrs:
            self.node_label_attrs[u].update(node_label_attrs[u])
    for mutations in self.node_mutations.values():
        for mutation in mutations:
            m = mutation.id + self.offsets.mutation
            # We need to offset the mutation symbol so that it's centred
            self.mutation_attrs[m] = {
                "d": "M -{0},-{0} l {1},{1} M -{0},{0} l {1},-{1}".format(
                    half_symbol_size, symbol_size
                )
            }
            if mutation_attrs is not None and m in mutation_attrs:
                self.mutation_attrs[m].update(mutation_attrs[m])
            add_class(self.mutation_attrs[m], "sym")  # class 'sym' for symbol
            label = ""
            if mutation_labels is None:
                label = str(m)
            elif m in mutation_labels:
                label = str(mutation_labels[m])
            self.mutation_label_attrs[m] = {"text": label}
            if mutation_label_attrs is not None and m in mutation_label_attrs:
                self.mutation_label_attrs[m].update(mutation_label_attrs[m])
            add_class(self.mutation_label_attrs[m], "lab")

    self.set_spacing(
        top=self.margin_top + (0 if title is None else self.line_height),
        left=self.margin_left,
        bottom=self.margin_bottom,
        right=self.margin_right,
    )
    if title is not None:
        self.add_text_in_group(
            title,
            self.drawing,
            pos=(self.plotbox.max_x / 2, 0),
            dominant_baseline="hanging",
            group_class="title",
            text_anchor="middle",
        )

    self.assign_x_coordinates()
    self.assign_y_coordinates(max_time, min_time, force_root_branch)
    tick_length_lower = self.default_tick_length  # TODO - parameterize
    tick_length_upper = self.default_tick_length_site  # TODO - parameterize
    if all_edge_mutations:
        self.shade_background(tree.interval, tick_length_lower)

    first_site, last_site = np.searchsorted(
        self.ts.sites_position, [self.left_extent, self.right_extent]
    )
    site_muts = {site_id: [] for site_id in range(first_site, last_site)}
    # Only use mutations plotted on the tree (not necessarily all at the site)
    for muts in self.node_mutations.values():
        for mut in muts:
            site_muts[mut.site].append(mut)

    self.draw_x_axis(
        tick_positions=np.array(tree.interval),
        tick_length_lower=tick_length_lower,
        tick_length_upper=tick_length_upper,
        site_muts=site_muts,
        x_regions=x_regions,
    )
    if y_ticks is None:
        y_ticks = {h: ts.node(u).time for u, h in sorted(self.node_height.items())}

    self.draw_y_axis(
        ticks=check_y_ticks(y_ticks),
        lower=self.timescaling.transform(self.timescaling.min_time),
        tick_length_outer=self.default_tick_length,
        gridlines=y_gridlines,
        side="right" if y_axis == "right" else "left",
    )
    self.draw_tree()
|
|
1912
|
+
|
|
1913
|
+
def process_mutations_over_node(self, u, low_bound, high_bound, ignore_times=False):
    """
    Sort ``self.node_mutations[u]`` in place, oldest mutation first.

    When every mutation above ``u`` has an UNKNOWN_TIME value, or when
    ``ignore_times`` is True, the mutations are ordered by site and then by
    parent (which ends up oldest first). Each one is then given a time evenly
    spaced strictly between ``high_bound`` and ``low_bound``, overwriting its
    ``time`` attribute. Otherwise the known times are used and the list is
    sorted by decreasing time.

    A mix of known and unknown mutation times is not currently allowed in a
    tree sequence, which keeps this logic simple. If it ever were allowed, the
    more complex handling could be encapsulated here.
    """
    mutations = self.node_mutations[u]
    unknown = [util.is_unknown_time(mut.time) for mut in mutations]
    if ignore_times is True or all(unknown):
        mutations.sort(key=operator.attrgetter("site", "parent"))
        span = high_bound - low_bound
        denominator = len(mutations) + 1
        for rank, mut in enumerate(mutations, start=1):
            mut.time = high_bound - span * rank / denominator
        return
    assert not any(unknown)
    mutations.sort(key=operator.attrgetter("time"), reverse=True)
|
|
1933
|
+
|
|
1934
|
+
def assign_y_coordinates(
    self,
    max_time,
    min_time,
    force_root_branch,
    bottom_space=SvgAxisPlot.line_height,
    top_space=SvgAxisPlot.line_height,
):
    """
    Create ``self.node_height``, ``self.timescaling`` and
    ``self.min_root_branch_plot_length`` for use in plotting.

    ``self.node_height`` maps each displayed node to an untransformed height.
    When ``self.time_scale`` is "rank", the height is the node's rank among
    the distinct node times; otherwise it is the node's time. The mutations
    above each displayed node are ordered, and given times where needed, by
    ``process_mutations_over_node``.

    Extra space is allowed within the plotbox: ``bottom_space`` at the bottom
    for leaf labels, and ``top_space`` above the topmost root node for root
    labels. ``top_space`` is used only when no root branches are plotted.
    """
    allow_numeric = self.time_scale != "rank"
    max_time = check_max_time(max_time, allow_numeric)
    min_time = check_min_time(min_time, allow_numeric)
    node_time = self.ts.nodes_time
    mut_time = self.ts.mutations_time
    root_branch_len = 0
    if self.time_scale == "rank":
        ranked_time = np.zeros_like(node_time)
        if max_time == "tree":
            # We only rank the times within the tree in this case.
            for u in self.node_x_coord.keys():
                ranked_time[u] = node_time[u]
        else:
            # only rank the nodes that are actually referenced in the edge table
            # (non-referenced nodes could occur if the user specifies x_lim values)
            # However, we do include nodes in trees that have been skipped
            use_time = edge_and_sample_nodes(self.ts)
            ranked_time[use_time] = node_time[use_time]
        node_time = ranked_time
        times = np.unique(node_time[node_time <= self.ts.max_root_time])
        depth = {time: rank for rank, time in enumerate(times)}
        if self.mutations_over_roots or force_root_branch:
            root_branch_len = 1  # Will get scaled later
        max_time = max(depth.values()) + root_branch_len
        if min_time in (None, "tree", "ts"):
            assert min(depth.values()) == 0
            min_time = 0
        # In pathological cases, all the nodes are at the same time
        if max_time == min_time:
            max_time = min_time + 1
        self.node_height = {u: depth[node_time[u]] for u in self.node_x_coord.keys()}
        for u in self.node_mutations.keys():
            if u in self.node_height:
                parent = self.tree.parent(u)
                if parent == NULL:
                    top = self.node_height[u] + root_branch_len
                else:
                    top = depth[node_time[parent]]
                self.process_mutations_over_node(
                    u, self.node_height[u], top, ignore_times=True
                )
    else:
        assert self.time_scale in ["time", "log_time"]
        self.node_height = {u: node_time[u] for u in self.node_x_coord.keys()}
        if max_time == "tree":
            max_node_height = max(self.node_height.values())
            max_mut_height = np.nanmax(
                [0] + [mut.time for m in self.node_mutations.values() for mut in m]
            )
            max_time = max(max_node_height, max_mut_height)  # Reuse variable
        elif max_time == "ts":
            max_node_height = self.ts.max_root_time
            max_mut_height = np.nanmax(np.append(mut_time, 0))
            max_time = max(max_node_height, max_mut_height)  # Reuse variable
        else:
            max_node_height = max_time
        if min_time == "tree":
            min_time = min(self.node_height.values())
            # don't need to check mutation times, as they must be above a node
        elif min_time == "ts":
            min_time = np.min(self.ts.nodes_time[edge_and_sample_nodes(self.ts)])
        # In pathological cases, all the nodes are at the same time
        if min_time == max_time:
            max_time = min_time + 1
        if self.mutations_over_roots or force_root_branch:
            # Define a minimum root branch length, after transformation if necessary
            if self.time_scale != "log_time":
                root_branch_len = (max_time - min_time) * self.root_branch_fraction
            else:
                max_plot_y = np.log(max_time + 1)
                diff_plot_y = max_plot_y - np.log(min_time + 1)
                root_plot_y = max_plot_y + diff_plot_y * self.root_branch_fraction
                root_branch_len = np.exp(root_plot_y) - 1 - max_time
            # If necessary, allow for this extra branch in max_time
            if max_node_height + root_branch_len > max_time:
                max_time = max_node_height + root_branch_len
        for u in self.node_mutations.keys():
            if u in self.node_height:
                parent = self.tree.parent(u)
                if parent == NULL:
                    # This is a root: if muts have no times we specify an upper time
                    top = self.node_height[u] + root_branch_len
                else:
                    top = node_time[parent]
                self.process_mutations_over_node(u, self.node_height[u], top)

    assert float(max_time) == max_time
    assert float(min_time) == min_time
    # Add extra space above the top and below the bottom of the tree to keep the
    # node labels within the plotbox (but top label space not needed if the
    # existence of a root branch pushes the whole tree + labels downwards anyway)
    top_space = 0 if root_branch_len > 0 else top_space
    self.timescaling = Timescaling(
        max_time=max_time,
        min_time=min_time,
        plot_min=self.plotbox.height + self.plotbox.top - bottom_space,
        plot_range=self.plotbox.height - top_space - bottom_space,
        use_log_transform=(self.time_scale == "log_time"),
    )

    # Calculate default root branch length to use (in plot coords). This is a
    # minimum, as branches with deep root mutations could be longer
    self.min_root_branch_plot_length = self.timescaling.transform(
        self.timescaling.max_time
    ) - self.timescaling.transform(self.timescaling.max_time + root_branch_len)
|
|
2055
|
+
|
|
2056
|
+
def assign_x_coordinates(self):
    """
    Set ``self.x_transform``, ``self.extra_line`` and ``self.node_x_coord``.

    ``self.x_transform`` maps a genome position linearly so that
    ``self.left_extent`` .. ``self.right_extent`` spans the plotbox width.

    ``self.node_x_coord`` maps each displayed node, in the order of
    ``self.postorder_nodes``, to an x coordinate. Tips are spaced one unit
    apart and each internal node is centred over its displayed children;
    the result is then rescaled to the plotbox. It is always defined, and is
    an empty dict when no node is displayed.

    When ``self.pack_untracked_polytomies`` is set, subtrees with no tracked
    samples are not laid out. A single such child of a displayed node is
    drawn as a condensed tip. Two or more are summarised by a
    ``PolytomyLine`` stored in ``self.extra_line[node]``.

    :raises ValueError: if ``self.postorder_nodes`` is not in postorder.
    """
    # Set up transformation for genome positions
    self.x_transform = lambda x: (
        (x - self.left_extent)
        / (self.right_extent - self.left_extent)
        * self.plotbox.width
        + self.plotbox.left
    )
    # Set up x positions for nodes
    node_xpos = {}
    self.extra_line = {}  # To store a dotted line to represent polytomies
    self.node_x_coord = {}  # Stays empty if no nodes are displayed
    leaf_x = 0  # First leaf starts at x=1, to give some space between Y axis & leaf
    tree = self.tree
    prev = tree.virtual_root
    for u in self.postorder_nodes:
        parent = tree.parent(u)
        omit = self.pack_untracked_polytomies and tree.num_tracked_samples(u) == 0
        if parent == prev:
            raise ValueError("Nodes must be passed in postorder to Tree.draw_svg()")
        # In postorder, u is a tip unless the previous node was one of its children
        is_tip = tree.parent(prev) != u
        if is_tip:
            if not omit:
                leaf_x += 1
                node_xpos[u] = leaf_x
        elif not omit:
            # Untracked children are available for packing into a polytomy summary
            untracked_children = []
            if self.pack_untracked_polytomies:
                untracked_children = [
                    c for c in tree.children(u) if tree.num_tracked_samples(c) == 0
                ]
            child_x = [node_xpos[c] for c in tree.children(u) if c in node_xpos]
            if len(untracked_children) == 1:
                # If only a single non-focal lineage, treat it as a condensed tip
                leaf_x += 1
                node_xpos[untracked_children[0]] = leaf_x
                child_x.append(leaf_x)
            elif len(untracked_children) > 1:
                # Otherwise show a horizontal line with the number of lineages
                # Extra length of line is equal to log of the polytomy size
                self.extra_line[u] = self.PolytomyLine(
                    len(untracked_children),
                    sum(tree.num_samples(v) for v in untracked_children),
                    [leaf_x, leaf_x + 1 + np.log(len(untracked_children))],
                )
                child_x.append(leaf_x + 1)
                leaf_x = self.extra_line[u].line_pos[1]
            assert len(child_x) != 0  # Must have prev hit something defined as a tip
            if len(child_x) == 1:
                node_xpos[u] = child_x[0]
            else:
                a = min(child_x)
                b = max(child_x)
                node_xpos[u] = a + (b - a) / 2
        prev = u
    # Now rescale to the plot width: leaf_x is the maximum value of the last leaf
    if len(node_xpos) > 0:
        scale = self.plotbox.width / leaf_x
        lft = self.plotbox.left - scale / 2
        self.node_x_coord = {k: lft + v * scale for k, v in node_xpos.items()}
        for line in self.extra_line.values():
            # line_pos is a list, so it can be rescaled in place
            line.line_pos[:] = [lft + pos * scale for pos in line.line_pos]
|
|
2122
|
+
|
|
2123
|
+
def info_classes(self, focal_node_id):
    """
    Return the sorted list of CSS class names describing node ``focal_node_id``.

    The classes encode:
        "node n<Y>": where <Y> == focal node id (added as a single string)
        "i<I>": where <I> == individual id (only if the node has one)
        "p<P>": where <P> == population id (only if the node has one)
        "a<X>" or "root": where <X> == id of immediate ancestor (parent) node
        "sample": if the node is a sample
        "c<N>" or "leaf": where <N> == number of direct children of this node
        "m<A>": where <A> == (offset) id of each mutation above the node
        "s<B>": where <B> == (offset) site id of each of those mutations
    """
    node = self.ts.node(focal_node_id)
    tree = self.tree
    parent = tree.parent(focal_node_id)
    classes = {f"node n{focal_node_id}", "root" if parent == NULL else f"a{parent}"}
    if node.individual != NULL:
        classes.add(f"i{node.individual}")
    if node.population != NULL:
        classes.add(f"p{node.population}")
    if tree.is_sample(focal_node_id):
        classes.add("sample")
    if tree.is_leaf(focal_node_id):
        classes.add("leaf")
    else:
        classes.add(f"c{tree.num_children(focal_node_id)}")
    # Tagging mutations and sites above this node allows identification of the
    # subtree under any specific mutation via CSS
    for mut in self.node_mutations[focal_node_id]:
        classes.update(
            (
                f"m{mut.id + self.offsets.mutation}",
                f"s{mut.site + self.offsets.site}",
            )
        )
    return sorted(classes)
|
|
2159
|
+
|
|
2160
|
+
def text_transform(self, position, dy=0):
    """
    Return an SVG ``translate(...)`` string placing a label relative to a symbol.

    :param position: one of "below", "above", "above_left", "above_right",
        "left" or "right".
    :param dy: extra vertical offset, added to the y translation for every
        position (positive values move the label down, since SVG y grows
        downwards).
    :raises KeyError: if ``position`` is not one of the names listed above.
    """
    line_h = self.text_height
    half_sym = self.symbol_size / 2
    transforms = {
        "below": f"translate(0 {rnd(line_h - half_sym + dy)})",
        "above": f"translate(0 {rnd(-(line_h - half_sym) + dy)})",
        "above_left": f"translate({rnd(-half_sym)} {rnd(-line_h / 2 + dy)})",
        # dy is added to the y offset here too, matching every other position
        "above_right": f"translate({rnd(half_sym)} {rnd(-line_h / 2 + dy)})",
        "left": f"translate({-rnd(2 + half_sym)} {rnd(dy)})",
        "right": f"translate({rnd(2 + half_sym)} {rnd(dy)})",
    }
    return transforms[position]
|
|
2172
|
+
|
|
2173
|
+
def draw_tree(self):
    """
    Add the SVG elements for the displayed tree to the plot box.

    Each displayed node becomes an SVG group, nested to mirror the tree
    hierarchy and translated relative to its parent's group. Within a node's
    group the elements are added in this order: any polytomy summary line,
    the edge above the node, mutation symbols and labels, then the node
    symbol (plus a trapezium for collapsed multi-sample tips) and its label.
    """
    # Note: the displayed tree may not be the same as self.tree, e.g. if the nodes
    # have been collapsed, or a subtree is being displayed. The node_x_coord
    # dictionary keys give the nodes of the displayed tree, in postorder.
    NodeDrawInfo = collections.namedtuple("NodeDrawInfo", ["pos", "is_tip"])
    dwg = self.drawing
    tree = self.tree
    left_child = get_left_child(tree, self.postorder_nodes)
    parent_array = tree.parent_array
    edge_array = tree.edge_array

    node_info = {}
    roots = []  # Roots of the displayed tree
    prev = tree.virtual_root
    for u, x in self.node_x_coord.items():  # Node ids `u` returned in postorder
        node_info[u] = NodeDrawInfo(
            pos=np.array([x, self.timescaling.transform(self.node_height[u])]),
            # Detect if this is a "tip" in the displayed tree, even if
            # it is not a leaf in the original tree, by looking at the prev parent
            # (in postorder, a node with displayed children is immediately
            # preceded by one of them)
            is_tip=(parent_array[prev] != u),
        )
        prev = u
        # A node whose parent is not displayed is a root of the displayed tree
        if parent_array[u] not in self.node_x_coord:
            roots.append(u)
    # Iterate over displayed nodes, adding groups to reflect the tree hierarchy
    stack = []
    for u in roots:
        x, y = node_info[u].pos
        # Root groups are positioned absolutely within the plot box
        grp = dwg.g(
            class_=" ".join(self.info_classes(u)),
            transform=f"translate({rnd(x)} {rnd(y)})",
        )
        stack.append((u, self.get_plotbox().add(grp)))

    # Preorder traversal, so we can create nested groups
    while len(stack) > 0:
        u, curr_svg_group = stack.pop()
        pu, is_tip = node_info[u]
        for focal in tree.children(u):
            if focal not in node_info:
                # Child is not part of the displayed tree
                continue
            # Child groups are translated relative to this node's position
            fx, fy = node_info[focal].pos - pu
            new_svg_group = curr_svg_group.add(
                dwg.g(
                    class_=" ".join(self.info_classes(focal)),
                    transform=f"translate({rnd(fx)} {rnd(fy)})",
                )
            )
            stack.append((focal, new_svg_group))

        o = (0, 0)  # Origin of this node's group (i.e. the node itself)
        v = parent_array[u]

        # Add polytomy line if necessary
        if u in self.extra_line:
            info = self.extra_line[u]
            x2 = info.line_pos[1] - pu[0]
            poly = dwg.g(class_="polytomy")
            poly.add(
                dwg.line(
                    start=(0, 0),
                    end=(x2, 0),
                )
            )
            label = dwg.text(
                f"+{info.num_samples}/{bold_integer(info.num_branches)}",
                font_style="italic",
                x=[rnd(x2)],
                dy=[rnd(-self.text_height / 10)],  # make the plus sign line up
                text_anchor="end",
            )
            label.set_desc(
                title=(
                    f"This polytomy has {info.num_branches} additional branches, "
                    f"leading to a total of {info.num_samples} descendant samples"
                )
            )
            poly.add(label)
            curr_svg_group.add(poly)

        # Add edge above node first => on layer underneath anything else
        draw_edge_above_node = False
        try:
            # Offset from this node to its (displayed) parent
            dx, dy = node_info[v].pos - pu
            draw_edge_above_node = True
        except KeyError:
            # Must be a root
            root_branch_l = self.min_root_branch_plot_length
            if root_branch_l > 0:
                if len(self.node_mutations[u]) > 0:
                    # Lengthen the root branch so it reaches the first mutation
                    # listed for this node
                    mtop = self.timescaling.transform(
                        self.node_mutations[u][0].time
                    )
                    root_branch_l = max(root_branch_l, pu[1] - mtop)
                dx, dy = 0, -root_branch_l
                draw_edge_above_node = True
        if draw_edge_above_node:
            edge_id_class = (
                "root" if edge_array[u] == tskit.NULL else f"e{edge_array[u]}"
            )
            add_class(self.edge_attrs[u], f"edge {edge_id_class}")
            # Elbow path: vertical up to the parent's height, then horizontal
            path = dwg.path(
                [("M", o), ("V", rnd(dy)), ("H", rnd(dx))], **self.edge_attrs[u]
            )
            curr_svg_group.add(path)

        # Add mutation symbols + labels
        for mutation in self.node_mutations[u]:
            # TODO get rid of these manual positioning tweaks and add them
            # as offsets the user can access via a transform or something.
            dy = self.timescaling.transform(mutation.time) - pu[1]
            mutation_id = mutation.id + self.offsets.mutation
            mutation_class = (
                f"mut m{mutation_id} " f"s{mutation.site + self.offsets.site}"
            )
            # Use the real mutation ID here, since we are referencing into the ts
            if util.is_unknown_time(self.ts.mutation(mutation.id).time):
                mutation_class += " unknown_time"
            if mutation_id in self.mutations_outside_tree:
                mutation_class += " extra"
            mut_group = curr_svg_group.add(
                dwg.g(class_=mutation_class, transform=f"translate(0 {rnd(dy)})")
            )
            # A line from the mutation to the node below, normally hidden, but
            # revealable if we want to flag the path below a mutation
            mut_group.add(dwg.line(end=(0, -rnd(dy))))
            # Symbols
            symbol = mut_group.add(dwg.path(**self.mutation_attrs[mutation_id]))
            if mutation_id in self.mutation_titles:
                symbol.set_desc(title=self.mutation_titles[mutation_id])
            # Labels: left-most children get labels on the left, others on the
            # right. NOTE(review): for a root, parent_array[u] is presumably
            # NULL (-1), which indexes the trailing extra slot of left_child
            # (always NULL), so roots always get right-hand labels.
            if u == left_child[parent_array[u]]:
                mut_label_class = "lft"
                transform = self.text_transform("left")
            else:
                mut_label_class = "rgt"
                transform = self.text_transform("right")
            add_class(self.mutation_label_attrs[mutation_id], mut_label_class)
            self.mutation_label_attrs[mutation_id]["transform"] = transform
            mut_group.add(dwg.text(**self.mutation_label_attrs[mutation_id]))

        # Add node symbol + label (visually above the edge subtending this node)
        # -> symbols: squares for samples, circles otherwise
        if tree.is_sample(u):
            symbol = curr_svg_group.add(dwg.rect(**self.node_attrs[u]))
        else:
            symbol = curr_svg_group.add(dwg.circle(**self.node_attrs[u]))
        multi_samples = None
        if (
            is_tip and tree.num_samples(u) > 1
        ):  # Multi-sample tip => trapezium shape
            multi_samples = tree.num_samples(u)
            trapezium_attrs = self.node_attrs[u].copy()
            # Remove the shape-styling attributes
            for unwanted_attr in ("size", "insert", "center", "r"):
                trapezium_attrs.pop(unwanted_attr, None)
            trapezium_attrs["points"] = [  # add a trapezium shape below the symbol
                (self.symbol_size / 2, 0),
                (self.symbol_size, self.symbol_size),
                (-self.symbol_size, self.symbol_size),
                (-self.symbol_size / 2, 0),
            ]
            add_class(trapezium_attrs, "multi")
            curr_svg_group.add(dwg.polygon(**trapezium_attrs))
        if u in self.node_titles:
            symbol.set_desc(title=self.node_titles[u])
        # -> labels
        node_lab_attr = self.node_label_attrs[u]
        if is_tip and multi_samples is None:
            # Ordinary tips: label below the symbol
            node_lab_attr["transform"] = self.text_transform("below")
        elif u in roots and self.min_root_branch_plot_length == 0:
            # Roots with no root branch drawn: label directly above
            node_lab_attr["transform"] = self.text_transform("above")
        else:
            if multi_samples is not None:
                # Collapsed tip: "+N" summary below the trapezium
                label = dwg.text(
                    text=f"+{multi_samples}",
                    transform=self.text_transform("below", dy=1),
                    font_style="italic",
                    class_="lab summary",
                )
                title = (
                    f"A collapsed {'sample' if tree.is_sample(u) else 'non-sample'} "
                    f"node with {multi_samples} descendant samples in this tree"
                )
                label.set_desc(title=title)
                curr_svg_group.add(label)
            # Label above, on the side away from the node's siblings
            if u == left_child[tree.parent(u)]:
                add_class(node_lab_attr, "lft")
                node_lab_attr["transform"] = self.text_transform("above_left")
            else:
                add_class(node_lab_attr, "rgt")
                node_lab_attr["transform"] = self.text_transform("above_right")
        curr_svg_group.add(dwg.text(**node_lab_attr))
|
|
2366
|
+
|
|
2367
|
+
|
|
2368
|
+
class TextTreeSequence:
    """
    Draw a tree sequence as a horizontal line of text trees.

    The canvas holds a column of node-time labels on the left. Every tree in
    the sequence follows it, drawn side by side and separated by vertical
    rules, with the breakpoint positions written along the bottom row.
    """

    def __init__(
        self,
        ts,
        node_labels=None,
        use_ascii=False,
        time_label_format=None,
        position_label_format=None,
        order=None,
    ):
        self.ts = ts

        if time_label_format is None:
            time_label_format = "{:.2f}"
        breakpoints = ts.breakpoints(as_array=True)
        if position_label_format is not None:
            position_scale_labels = [
                position_label_format.format(pos) for pos in breakpoints
            ]
        else:
            position_scale_labels = create_tick_labels(breakpoints)

        node_times = ts.tables.nodes.time
        time_scale_labels = [
            time_label_format.format(node_times[node_id])
            for node_id in range(ts.num_nodes)
        ]

        # All trees are laid out on the tree-sequence-wide time scale; the
        # time labels below are placed using the first tree's rows.
        text_trees = []
        for tree in self.ts.trees():
            text_trees.append(
                VerticalTextTree(
                    tree,
                    max_time="ts",
                    node_labels=node_labels,
                    use_ascii=use_ascii,
                    order=order,
                )
            )

        # One extra row at the bottom for the position labels.
        self.height = max(t.height for t in text_trees) + 1
        self.width = sum(t.width + 2 for t in text_trees) - 1
        label_col_width = max(len(lab) for lab in time_scale_labels)
        self.width += label_col_width + 3 + len(position_scale_labels[-1]) // 2

        self.canvas = np.zeros((self.height, self.width), dtype=str)
        self.canvas[:] = " "

        separator = "|" if use_ascii else "┊"

        def write_position_label(text, column):
            # Bottom-row label, centred on the separator two columns left of
            # `column` (clamped to the left edge).
            chars = to_np_unicode(text)
            size = len(chars)
            start = max(column - size // 2 - 2, 0)
            self.canvas[-1, start : start + size] = chars

        first_positions = text_trees[0].time_position
        for node_id, label in enumerate(map(to_np_unicode, time_scale_labels)):
            self.canvas[first_positions[node_id], 0 : label.shape[0]] = label
        self.canvas[:, label_col_width] = separator

        column = label_col_width + 2
        for index, text_tree in enumerate(text_trees):
            write_position_label(position_scale_labels[index], column)
            tree_h, tree_w = text_tree.canvas.shape
            # Copy the tree in above the label row, dropping its newline column.
            self.canvas[-tree_h - 1 : -1, column : column + tree_w - 1] = (
                text_tree.canvas[:, :-1]
            )
            column += tree_w
            self.canvas[:, column] = separator
            column += 2

        write_position_label(position_scale_labels[-1], column)
        self.canvas[:, -1] = "\n"

    def __str__(self):
        flat = self.canvas.reshape(self.width * self.height)
        return "".join(flat)
|
|
2445
|
+
|
|
2446
|
+
|
|
2447
|
+
def to_np_unicode(string):
    """
    Converts the specified string to a 1-D numpy unicode array holding one
    character per element (dtype ``<U1``).

    Passing the string straight to ``np.array`` would create a zero-dimensional
    ``Un`` array. Assigning that into a canvas slice would broadcast the whole
    string, so the string is split into a list of characters first.

    :param str string: The text to convert (may be empty).
    :return: A numpy array of shape ``(len(string),)``.
    """
    return np.array(list(string), dtype="U1")
|
|
2460
|
+
|
|
2461
|
+
|
|
2462
|
+
def get_left_neighbour(tree, traversal_order):
    """
    Returns the left neighbour of each node in the tree according to the
    specified traversal order. The left neighbour is the closest node in terms
    of path distance to the left of a given node. For a node with an earlier
    sibling this is the immediately preceding sibling; for a first child it is
    the parent's left neighbour. Nodes with nothing to their left get NULL.

    The traversal uses an explicit stack rather than recursion, so very deep
    trees (e.g. caterpillars) do not exceed Python's recursion limit.

    :param tree: The tree to examine.
    :param traversal_order: An order accepted by ``tree.nodes(order=...)``;
        it fixes the left-to-right order of siblings and of roots.
    :return: An integer numpy array with one entry per node in the tree
        sequence; nodes not in the tree are NULL.
    """
    # The traversal order will define the order of children and roots.
    # Root order is defined by this traversal, and the roots are
    # the children of -1
    children = collections.defaultdict(list)
    for u in tree.nodes(order=traversal_order):
        children[tree.parent(u)].append(u)

    # One spare trailing slot so index -1 (the pseudo-parent of the roots)
    # can be written without clobbering a real node; it is sliced off below.
    left_neighbour = np.full(tree.tree_sequence.num_nodes + 1, NULL, dtype=int)

    # A node's neighbour depends only on its parent's neighbour and on its
    # preceding sibling, so any top-down visiting order gives the same result.
    # The children of -1 are the roots and the neighbour of all left-most
    # nodes in the tree is also -1 (NULL)
    stack = [(-1, -1)]
    while stack:
        u, neighbour = stack.pop()
        left_neighbour[u] = neighbour
        for v in children.get(u, ()):
            stack.append((v, neighbour))
            neighbour = v

    return left_neighbour[:-1]
|
|
2488
|
+
|
|
2489
|
+
|
|
2490
|
+
def get_left_child(tree, postorder_nodes):
    """
    Map every node to its left-most child, where "left-most" means the first
    of its children to appear in ``postorder_nodes``.

    :param tree: The tree supplying parent relationships.
    :param postorder_nodes: Iterable of node IDs in postorder.
    :return: An integer numpy array with ``num_nodes + 1`` entries. Nodes with
        no child in ``postorder_nodes`` map to NULL. The trailing extra entry
        (what an index of NULL == -1 refers to) is also always NULL.
    """
    result = np.full(tree.tree_sequence.num_nodes + 1, NULL, dtype=int)
    for child in postorder_nodes:
        parent = tree.parent(child)
        # Skip roots, and parents whose left-most child is already recorded.
        if parent == NULL or result[parent] != NULL:
            continue
        result[parent] = child
    return result
|
|
2502
|
+
|
|
2503
|
+
|
|
2504
|
+
def node_time_depth(tree, min_branch_length=None, max_time="tree"):
    """
    Compute an integer drawing depth for nodes from their times.

    Nodes are processed in order of increasing time, so the youngest nodes
    get depth 0 and depth grows with node time. All nodes sharing a time get
    the same depth. That depth is at least ``depth[c] + min_branch_length[c]``
    for every child ``c`` of those nodes, and at least 2 more than the depth
    of the previous (younger) time slice.

    :param tree: The tree to lay out.
    :param dict min_branch_length: Maps node ID to the minimum length of the
        branch above that node. If None (the default), every node gets 1.
    :param str max_time: ``"tree"`` lays out only the nodes in ``tree``.
        ``"ts"`` lays out every node in the tree sequence, taking children
        from all of its edges, so depths are consistent across trees.
    :return: A tuple ``(depth, total_depth)``. ``depth`` maps node ID to its
        depth. ``total_depth`` is the last slice's depth plus 2; for
        ``"tree"`` it is also raised to at least
        ``depth[root] + min_branch_length[root]`` for every root.
    """
    if min_branch_length is None:
        min_branch_length = {u: 1 for u in range(tree.tree_sequence.num_nodes)}

    # Group nodes by time, and pick how children are found for each case.
    time_node_map = collections.defaultdict(list)
    if max_time == "tree":
        for u in tree.nodes():
            time_node_map[tree.time(u)].append(u)
        children_of = tree.children
    else:
        assert max_time == "ts"
        ts = tree.tree_sequence
        for node in ts.nodes():
            time_node_map[node.time].append(node.id)
        edge_children = collections.defaultdict(list)
        for edge in ts.edges():
            edge_children[edge.parent].append(edge.child)

        def children_of(u):
            # Children of u over all edges in the tree sequence.
            return edge_children.get(u, ())

    depth = {}
    current_depth = 0
    for t in sorted(time_node_map.keys()):
        nodes_at_t = time_node_map[t]
        # Children are strictly younger, so their depths are already known.
        for u in nodes_at_t:
            for v in children_of(u):
                current_depth = max(current_depth, depth[v] + min_branch_length[v])
        for u in nodes_at_t:
            depth[u] = current_depth
        current_depth += 2

    if max_time == "tree":
        # Leave room for the branch drawn above each root.
        for root in tree.roots:
            current_depth = max(current_depth, depth[root] + min_branch_length[root])

    return depth, current_depth
|
|
2549
|
+
|
|
2550
|
+
|
|
2551
|
+
class TextTree:
    """
    Draws a representation of a tree using unicode (or ASCII) drawing
    characters written to a 2D array of single characters.

    Subclasses supply the layout through ``default_node_label``,
    ``_assign_time_positions``, ``_assign_traversal_positions`` (together
    these must set ``self.height`` and ``self.width``) and ``_draw``.
    """

    def __init__(
        self,
        tree,
        node_labels=None,
        max_time=None,
        min_time=None,
        use_ascii=False,
        orientation=None,
        order=None,
    ):
        self.tree = tree
        self.traversal_order = check_order(order)
        self.max_time = check_max_time(max_time, allow_numeric=False)
        self.min_time = check_min_time(min_time, allow_numeric=False)
        self.use_ascii = use_ascii
        self.orientation = check_orientation(orientation)
        if use_ascii:
            self.horizontal_line_char = "-"
            self.vertical_line_char = "|"
        else:
            self.horizontal_line_char = "━"
            self.vertical_line_char = "┃"
        # Filled in by the subclass placement algorithms called below.
        self.width = None
        self.height = None
        self.canvas = None
        # One placement axis comes from the traversal order, the other from
        # node time; subclasses map these onto rows/columns according to
        # their orientation.
        self.traversal_position = {}  # node -> position in traversal space
        self.time_position = {}

        # Node labels: the node ID by default. If explicit labels are given,
        # unlisted nodes get the subclass's blank (line-character) label.
        if node_labels is None:
            self.node_labels = {u: str(u) for u in tree.nodes()}
        else:
            self.node_labels = {u: self.default_node_label for u in tree.nodes()}
            self.node_labels.update(node_labels.items())

        self._assign_time_positions()
        self._assign_traversal_positions()
        self.canvas = np.zeros((self.height, self.width), dtype=str)
        self.canvas[:] = " "
        self._draw()
        # The final column carries the line terminators used by __str__.
        self.canvas[:, -1] = "\n"

    def __str__(self):
        flat = self.canvas.reshape(self.width * self.height)
        return "".join(flat)
|
|
2612
|
+
|
|
2613
|
+
|
|
2614
|
+
class VerticalTextTree(TextTree):
    """
    Text tree rendering with time on the vertical axis. With the TOP
    orientation, root nodes are at the top and time goes downwards into the
    present; otherwise the roots are at the bottom.

    The canvas is always drawn with the youngest nodes in the lowest-numbered
    rows (roots towards the bottom). It is flipped vertically at the end of
    ``_draw`` when the orientation is TOP.
    """

    @property
    def default_node_label(self):
        # Unlabelled nodes are drawn as a piece of vertical line.
        return self.vertical_line_char

    def _assign_time_positions(self):
        """
        Set ``self.time_position`` (node -> canvas row, before any flip) and
        ``self.height`` from the node time depths.
        """
        tree = self.tree
        # TODO when we add mutations to the text tree we'll need to take it into
        # account here. Presumably we need to get the maximum number of mutations
        # per branch.
        self.time_position, total_depth = node_time_depth(tree, max_time=self.max_time)
        self.height = total_depth - 1

    def _assign_traversal_positions(self):
        """
        Set ``self.traversal_position`` (node -> column of its vertical line),
        ``self.label_x`` (node -> first column of its label) and ``self.width``.

        Leaves are placed left to right in traversal order, each taking its
        label width plus one space. Internal nodes are centred over their
        children, and their labels are pushed right so they start after the
        column of their left neighbour.
        """
        self.label_x = {}
        left_neighbour = get_left_neighbour(self.tree, self.traversal_order)
        x = 0  # Next free column (to the right of everything placed so far)
        # NOTE(review): children's positions are read before their parent is
        # processed, so this assumes traversal_order visits children before
        # parents (a postorder) — confirm against check_order.
        for u in self.tree.nodes(order=self.traversal_order):
            label_size = len(self.node_labels[u])
            if self.tree.is_leaf(u):
                # Line sits in the middle of the leaf's label
                self.traversal_position[u] = x + label_size // 2
                self.label_x[u] = x
                x += label_size + 1
            else:
                coords = [self.traversal_position[c] for c in self.tree.children(u)]
                if len(coords) == 1:
                    # Single child: continue the same vertical line
                    self.traversal_position[u] = coords[0]
                else:
                    a = min(coords)
                    b = max(coords)
                    child_mid = int(round(a + (b - a) / 2))
                    self.traversal_position[u] = child_mid
                self.label_x[u] = self.traversal_position[u] - label_size // 2
                neighbour_x = -1
                neighbour = left_neighbour[u]
                if neighbour != NULL:
                    neighbour_x = self.traversal_position[neighbour]
                # Do not let the label start on or before the neighbour's line
                self.label_x[u] = max(neighbour_x + 1, self.label_x[u])
                x = max(x, self.label_x[u] + label_size + 1)
            assert self.label_x[u] >= 0
        self.width = x

    def _draw(self):
        """
        Write node labels and connecting lines onto ``self.canvas``.

        In the unflipped canvas a parent lies in a higher-numbered row than its
        children, and the horizontal connector for a multi-child node sits in
        the row just below the parent's index (``yu - 1``).
        """
        if self.use_ascii:
            left_child = "+"
            right_child = "+"
            mid_parent = "+"
            mid_parent_child = "+"
            mid_child = "+"
        elif self.orientation == TOP:
            # Characters chosen so they read correctly after the vertical flip
            left_child = "┏"
            right_child = "┓"
            mid_parent = "┻"
            mid_parent_child = "╋"
            mid_child = "┳"
        else:
            left_child = "┗"
            right_child = "┛"
            mid_parent = "┳"
            mid_parent_child = "╋"
            mid_child = "┻"

        # NOTE(review): child labels overwrite the top cell of the line drawn
        # by their parent, which presumably relies on tree.nodes() visiting
        # parents before children (preorder) — confirm.
        for u in self.tree.nodes():
            xu = self.traversal_position[u]
            yu = self.time_position[u]
            label = to_np_unicode(self.node_labels[u])
            label_len = label.shape[0]
            label_x = self.label_x[u]
            assert label_x >= 0
            self.canvas[yu, label_x : label_x + label_len] = label
            children = self.tree.children(u)
            if len(children) > 0:
                if len(children) == 1:
                    # Straight vertical line down to the only child
                    yv = self.time_position[children[0]]
                    self.canvas[yv:yu, xu] = self.vertical_line_char
                else:
                    left = min(self.traversal_position[v] for v in children)
                    right = max(self.traversal_position[v] for v in children)
                    y = yu - 1  # Connector row
                    self.canvas[y, left + 1 : right] = self.horizontal_line_char
                    self.canvas[y, xu] = mid_parent
                    for v in children:
                        xv = self.traversal_position[v]
                        yv = self.time_position[v]
                        self.canvas[yv:yu, xv] = self.vertical_line_char
                        # A child directly under the parent gets a crossing
                        mid_char = mid_parent_child if xv == xu else mid_child
                        self.canvas[y, xv] = mid_char
                    self.canvas[y, left] = left_child
                    self.canvas[y, right] = right_child
        if self.orientation == TOP:
            self.canvas = np.flip(self.canvas, axis=0)
            # Reverse the time positions so that we can use them in the tree
            # sequence drawing as well.
            flipped_time_position = {
                u: self.height - y - 1 for u, y in self.time_position.items()
            }
            self.time_position = flipped_time_position
|
|
2716
|
+
|
|
2717
|
+
|
|
2718
|
+
class HorizontalTextTree(TextTree):
    """
    Text tree rendering with time on the horizontal axis. With the LEFT
    orientation, root nodes are at the left and time goes rightwards into the
    present; otherwise the roots are at the right.

    The canvas is always drawn root-right (youngest nodes in column 0). It is
    mirrored horizontally at the end of ``_draw`` when the orientation is LEFT.
    """

    @property
    def default_node_label(self):
        # Unlabelled nodes are drawn as a piece of horizontal line.
        return self.horizontal_line_char

    def _assign_time_positions(self):
        """
        Set ``self.time_position`` (node -> starting column of its label,
        before any flip) and ``self.width`` from the node time depths.
        """
        # TODO when we add mutations to the text tree we'll need to take it into
        # account here. Presumably we need to get the maximum number of mutations
        # per branch.
        # Each branch must be long enough to hold the child's label plus one
        # character of line. NOTE(review): self.max_time is not passed here, so
        # the "tree" layout is always used for horizontal trees.
        self.time_position, total_depth = node_time_depth(
            self.tree, {u: 1 + len(self.node_labels[u]) for u in self.tree.nodes()}
        )
        self.width = total_depth

    def _assign_traversal_positions(self):
        """
        Set ``self.traversal_position`` (node -> canvas row) and ``self.height``.

        Leaves occupy every other row in traversal order, with an extra blank
        row after each root's subtree. Internal nodes are centred between
        their children.
        """
        y = 0
        for root in self.tree.roots:
            # NOTE(review): children's rows are read before their parent is
            # processed, so this assumes traversal_order is a postorder.
            for u in self.tree.nodes(root, order=self.traversal_order):
                if self.tree.is_leaf(u):
                    self.traversal_position[u] = y
                    y += 2
                else:
                    coords = [self.traversal_position[c] for c in self.tree.children(u)]
                    if len(coords) == 1:
                        self.traversal_position[u] = coords[0]
                    else:
                        a = min(coords)
                        b = max(coords)
                        child_mid = int(round(a + (b - a) / 2))
                        self.traversal_position[u] = child_mid
            y += 1  # Gap between the subtrees of different roots
        # Remove the trailing leaf spacing and root gap
        self.height = y - 2

    def _draw(self):
        """
        Write node labels and connecting lines onto ``self.canvas``.

        Each node's label starts at its time column. For a multi-child node
        the vertical connector sits in the column just left of the parent's
        label (``xu - 1``).
        """
        if self.use_ascii:
            top_across = "+"
            bot_across = "+"
            mid_parent = "+"
            mid_parent_child = "+"
            mid_child = "+"
        elif self.orientation == LEFT:
            # Characters chosen so they read correctly after the horizontal flip
            top_across = "┏"
            bot_across = "┗"
            mid_parent = "┫"
            mid_parent_child = "╋"
            mid_child = "┣"
        else:
            top_across = "┓"
            bot_across = "┛"
            mid_parent = "┣"
            mid_parent_child = "╋"
            mid_child = "┫"

        # Draw in root-right mode as the coordinates go in the expected direction.
        # NOTE(review): children's labels overwrite the start of the lines drawn
        # by their parent, which presumably relies on tree.nodes() visiting
        # parents before children (preorder) — confirm.
        for u in self.tree.nodes():
            yu = self.traversal_position[u]
            xu = self.time_position[u]
            label = to_np_unicode(self.node_labels[u])
            if self.orientation == LEFT:
                # We flip the array at the end so need to reverse the label.
                label = label[::-1]
            label_len = label.shape[0]
            self.canvas[yu, xu : xu + label_len] = label
            children = self.tree.children(u)
            if len(children) > 0:
                if len(children) == 1:
                    # Straight horizontal line back to the only child
                    xv = self.time_position[children[0]]
                    self.canvas[yu, xv:xu] = self.horizontal_line_char
                else:
                    # `bot` is the smallest row index, i.e. the visually
                    # top-most child, hence it receives the `top_across` glyph.
                    bot = min(self.traversal_position[v] for v in children)
                    top = max(self.traversal_position[v] for v in children)
                    x = xu - 1  # Connector column
                    self.canvas[bot + 1 : top, x] = self.vertical_line_char
                    self.canvas[yu, x] = mid_parent
                    for v in children:
                        yv = self.traversal_position[v]
                        xv = self.time_position[v]
                        self.canvas[yv, xv:x] = self.horizontal_line_char
                        # A child on the parent's row gets a crossing
                        mid_char = mid_parent_child if yv == yu else mid_child
                        self.canvas[yv, x] = mid_char
                    self.canvas[bot, x] = top_across
                    self.canvas[top, x] = bot_across
        if self.orientation == LEFT:
            self.canvas = np.flip(self.canvas, axis=1)
            # Move the padding to the left.
            self.canvas[:, :-1] = self.canvas[:, 1:]
            self.canvas[:, -1] = " "
|