tskit 1.0.1__cp314-cp314-macosx_10_15_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tskit/drawing.py ADDED
@@ -0,0 +1,2809 @@
1
+ # MIT License
2
+ #
3
+ # Copyright (c) 2018-2025 Tskit Developers
4
+ # Copyright (c) 2015-2017 University of Oxford
5
+ #
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+ #
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+ #
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+ """
24
+ Module responsible for visualisations.
25
+ """
26
+ import collections
27
+ import itertools
28
+ import logging
29
+ import math
30
+ import numbers
31
+ import operator
32
+ import warnings
33
+ import xml.dom.minidom
34
+ from collections.abc import Mapping
35
+ from dataclasses import dataclass
36
+
37
+ import numpy as np
38
+
39
+ import tskit
40
+ import tskit.util as util
41
+ from _tskit import NODE_IS_SAMPLE
42
+ from _tskit import NULL
43
+
44
# Names of the supported plot orientations (validated by check_orientation)
LEFT, RIGHT = "left", "right"
TOP, BOTTOM = "top", "bottom"

# Bit flags describing whether (and how) to plot each tree in a tree sequence.
# They are combined with bitwise-or in the per-tree status array built by clip_ts.
OMIT = 1 << 0
LEFT_CLIP = 1 << 1
RIGHT_CLIP = 1 << 2
OMIT_MIDDLE = 1 << 3
54
+
55
+
56
+ # Minimal SVG generation module to replace svgwrite for tskit visualization.
57
+ # This implementation provides only the functionality needed for the visualization
58
+ # code while maintaining the same API as svgwrite.
59
+
60
+
61
class Element:
    """
    A minimal XML element: a tag name, a mapping of attributes and an ordered list
    of children.

    Children are either other ``Element`` instances or arbitrary objects that are
    written out verbatim via ``str()``. Neither attribute values nor text children
    are XML-escaped by this class.
    """

    def __init__(self, tag, **kwargs):
        """
        Create an element called ``tag``. Keyword arguments become attributes: a
        single trailing underscore is stripped (so ``class_`` can be used for the
        ``class`` attribute) and remaining underscores are turned into hyphens.
        """
        self.tag = tag
        self.attrs = {}
        self.children = []

        # Process kwargs in alphabetical order so attribute order is deterministic
        for key in sorted(kwargs):
            value = kwargs[key]
            # Handle class_ special case for class attribute
            if key.endswith("_"):
                key = key[:-1]
            key = key.replace("_", "-")
            self.attrs[key] = value

    def __getitem__(self, key):
        """Return the attribute ``key``, or the empty string if it is not set."""
        return self.attrs.get(key, "")

    def __setitem__(self, key, value):
        """Set the attribute ``key`` to ``value`` (the key is used unchanged)."""
        self.attrs[key] = value

    def add(self, child):
        """Append ``child`` to this element's children and return the child."""
        self.children.append(child)
        return child

    def set_desc(self, **kwargs):
        """
        If a ``title`` keyword is given, append a ``<title>`` child containing it.
        Other keywords are ignored. Returns this element.
        """
        if "title" in kwargs:
            title_elem = Element("title")
            title_elem.children.append(kwargs["title"])
            self.children.append(title_elem)
        return self

    def _attr_str(self):
        """
        Return the attributes formatted as ``key="value"`` pairs separated by spaces.
        List or tuple values are space-joined, except for ``points`` whose items are
        (x, y) pairs written as ``x,y``.
        """
        result = []
        for key, value in self.attrs.items():
            if isinstance(value, (list, tuple)):
                # Handle points lists (for polygon/polyline)
                if key == "points":
                    points_str = " ".join(f"{x},{y}" for x, y in value)
                    result.append(f'{key}="{points_str}"')
                else:
                    result.append(f'{key}="{" ".join(map(str, value))}"')
            else:
                result.append(f'{key}="{value}"')
        return " ".join(result)

    def tostring(self):
        """
        Serialise this element and all its descendants to a string, without
        recursion. Children are emitted in the order in which they were added,
        whether they are elements or text.
        """
        # Each stack entry is (item, is_closing_tag). Text children go through the
        # stack as well: writing them out immediately would place them before any
        # element siblings and reverse their relative order.
        stack = [(self, False)]
        result = []

        while stack:
            item, is_closing_tag = stack.pop()
            if not isinstance(item, Element):
                result.append(str(item))
                continue
            if is_closing_tag:
                result.append(f"</{item.tag}>")
                continue
            attr_str = item._attr_str()
            start = f"<{item.tag}"
            if attr_str:
                start += f" {attr_str}"
            if not item.children:
                result.append(f"{start}/>")
            else:
                result.append(f"{start}>")
                stack.append((item, True))
                # Reversed, so that the first child is popped (and written) first
                stack.extend((child, False) for child in reversed(item.children))

        return "".join(result)
132
+
133
+
134
class Drawing:
    """
    Minimal replacement for the svgwrite ``Drawing`` class: owns the root ``<svg>``
    element and a ``<defs>`` element, and provides factory methods which create
    (but do not attach) the elements needed by the visualisation code.
    """

    def __init__(self, size=None, **kwargs):
        """
        Create the root element. ``kwargs`` are extra root attributes which override
        the defaults; if ``size`` is given it sets ``width`` and ``height``.
        """
        root_attrs = {
            "version": "1.1",
            "xmlns": "http://www.w3.org/2000/svg",
            "xmlns:ev": "http://www.w3.org/2001/xml-events",
            "xmlns:xlink": "http://www.w3.org/1999/xlink",
            "baseProfile": "full",
        }
        root_attrs.update(kwargs)
        if size is not None:
            root_attrs["width"], root_attrs["height"] = size[0], size[1]

        self.root = Element("svg", **root_attrs)
        self.root.add("")  # First root elem is a blank preamble
        self.defs = self.root.add(Element("defs"))

    def add(self, element):
        """Append ``element`` to the root element and return it."""
        return self.root.add(element)

    def g(self, **kwargs):
        """Create a ``<g>`` (group) element."""
        return Element("g", **kwargs)

    def rect(self, insert=None, size=None, **kwargs):
        """Create a ``<rect>``, with x/y taken from ``insert`` and width/height from
        ``size`` when these are provided."""
        if insert:
            kwargs.update(x=insert[0], y=insert[1])
        if size:
            kwargs.update(width=size[0], height=size[1])
        return Element("rect", **kwargs)

    def circle(self, center=None, r=None, **kwargs):
        """Create a ``<circle>``, with cx/cy taken from ``center`` and radius ``r``
        when these are provided."""
        if center:
            kwargs.update(cx=center[0], cy=center[1])
        if r:
            kwargs["r"] = r
        return Element("circle", **kwargs)

    def line(self, start=None, end=None, **kwargs):
        """Create a ``<line>`` from ``start`` to ``end``; a missing end point is
        placed at the origin."""
        x1, y1 = (start[0], start[1]) if start else (0, 0)
        x2, y2 = (end[0], end[1]) if end else (0, 0)  # pragma: not covered
        kwargs.update(x1=x1, y1=y1, x2=x2, y2=y2)
        return Element("line", **kwargs)

    def polyline(self, points=None, **kwargs):
        """Create a ``<polyline>`` through the given (x, y) ``points``."""
        if points:
            kwargs["points"] = points
        return Element("polyline", **kwargs)

    def polygon(self, points=None, **kwargs):
        """Create a ``<polygon>`` with the given (x, y) ``points``."""
        if points:
            kwargs["points"] = points
        return Element("polygon", **kwargs)

    def path(self, d=None, **kwargs):
        """
        Create a ``<path>``. ``d`` may be a ready-made path string, or a list of
        command tuples ``(letter, param, ...)`` in which tuple parameters are
        flattened into their individual coordinates.
        """
        if isinstance(d, list):
            segments = []
            for command in d:
                # Anything that is not a (letter, params...) tuple is skipped
                if not (isinstance(command, tuple) and len(command) >= 2):
                    continue
                letter, *arguments = command
                flattened = []
                for argument in arguments:
                    if isinstance(argument, tuple):
                        flattened.extend(map(str, argument))
                    else:
                        flattened.append(str(argument))
                segments.append(f"{letter} {' '.join(flattened)}")
            kwargs["d"] = " ".join(segments).strip()
        elif d:
            kwargs["d"] = d
        return Element("path", **kwargs)

    def text(self, text=None, **kwargs):
        """Create a ``<text>`` element, containing ``text`` if it is non-empty."""
        text_elem = Element("text", **kwargs)
        if text:
            text_elem.children.append(text)
        return text_elem

    def style(self, content):
        """Create a ``<style>`` element holding the CSS in ``content``."""
        style_elem = Element("style", type="text/css")
        if content:
            # Use CDATA to avoid having to escape special characters in CSS
            style_elem.children.append(f"<![CDATA[{content}]]>")
        return style_elem

    def tostring(self, pretty=False):
        """Return the document as a string; if ``pretty``, it is re-parsed and
        indented using ``xml.dom.minidom``."""
        raw = self.root.tostring()
        if not pretty:
            return raw
        return xml.dom.minidom.parseString(raw).toprettyxml()

    def saveas(self, path, pretty=False):
        """Write the document to the file at ``path``, encoded as UTF-8."""
        with open(path, "w", encoding="utf-8") as svg_file:
            svg_file.write(self.tostring(pretty=pretty))
243
+
244
+
245
@dataclass
class Offsets:
    """
    Offsets for trees, sites and mutations, filled in by clip_ts. Used when x_lim
    is set, and the displayed ts has been cut down by keep_intervals.
    """

    tree: int = 0
    site: int = 0
    mutation: int = 0
252
+
253
+
254
@dataclass(frozen=True)
class Timescaling:
    """
    Transforms the time axis. After construction, ``transform`` is bound to either
    ``log_transform`` or ``linear_transform`` depending on ``use_log_transform``.
    """

    max_time: float
    min_time: float
    plot_min: float
    plot_range: float
    use_log_transform: bool

    def __post_init__(self):
        if self.plot_range < 0:
            raise ValueError("Image size too small to allow space to plot tree")
        if not self.use_log_transform:
            chosen = self.linear_transform
        elif self.min_time < 0:
            raise ValueError("Cannot use a log scale if there are negative times")
        else:
            chosen = self.log_transform
        # The dataclass is frozen, so go around its own __setattr__
        object.__setattr__(self, "transform", chosen)

    def log_transform(self, y):
        """Standard log transform, in which 1 is added to all times when min_time
        is 0, so that values of 0 are allowed."""
        shift = 1 if self.min_time == 0 else 0
        log_low = np.log(self.min_time + shift)
        log_high = np.log(self.max_time + shift)
        y_scale = self.plot_range / (log_high - log_low)
        return self.plot_min - (np.log(y + shift) - log_low) * y_scale

    def linear_transform(self, y):
        """Map times linearly, so that min_time maps to plot_min and max_time maps
        to plot_min - plot_range."""
        time_span = self.max_time - self.min_time
        y_scale = self.plot_range / time_span
        return self.plot_min - (y - self.min_time) * y_scale
285
+
286
+
287
class SVGString(str):
    """A ``str`` subclass whose contents are an SVG representation."""

    def _repr_svg_(self):
        """
        Hook called by jupyter notebooks to render trees: the SVG is the string
        itself, so simply return it.
        """
        return self
295
+
296
+
297
def check_orientation(orientation):
    """
    Checks the specified orientation is valid and returns it in lower case.

    ``None`` selects the default, ``TOP``. Any other value is lower-cased and must
    then be one of ``LEFT``, ``RIGHT``, ``TOP`` or ``BOTTOM``.

    :raises ValueError: If the orientation is not one of the supported values.
    """
    if orientation is None:
        return TOP
    orientation = orientation.lower()
    orientations = [LEFT, RIGHT, TOP, BOTTOM]
    if orientation not in orientations:
        raise ValueError(f"Unknown orientation: choose from {orientations}")
    return orientation
306
+
307
+
308
def check_max_time(max_time, allow_numeric=True):
    """
    Checks the specified max_time is valid and sets the default ("tree") if None.

    Valid values are "tree", "ts" or, when ``allow_numeric`` is true, any real
    number. A ValueError is raised for anything else.
    """
    if max_time is None:
        max_time = "tree"
    if max_time in ("tree", "ts"):
        return max_time
    if not allow_numeric:
        raise ValueError("max_time must be 'tree' or 'ts'")
    if not isinstance(max_time, numbers.Real):
        raise ValueError("max_time must be a numeric value or one of 'tree' or 'ts'")
    return max_time
317
+
318
+
319
def check_min_time(min_time, allow_numeric=True):
    """
    Checks the specified min_time is valid and sets the default ("tree") if None.

    Valid values are "tree", "ts" or, when ``allow_numeric`` is true, any real
    number. A ValueError is raised for anything else.
    """
    if min_time is None:
        min_time = "tree"
    if min_time in ("tree", "ts"):
        return min_time
    if not allow_numeric:
        raise ValueError("min_time must be 'tree' or 'ts'")
    if not isinstance(min_time, numbers.Real):
        raise ValueError("min_time must be a numeric value or one of 'tree' or 'ts'")
    return min_time
332
+
333
+
334
def check_time_scale(time_scale):
    """
    Checks the specified time_scale is valid and sets the default ("time") if None.
    """
    chosen = "time" if time_scale is None else time_scale
    if chosen in ("time", "log_time", "rank"):
        return chosen
    raise ValueError("time_scale must be 'time', 'log_time' or 'rank'")
340
+
341
+
342
def check_format(format):  # noqa A002
    """
    Checks the specified output format is supported and returns it in lower case,
    defaulting to "svg" if None. The error message quotes the value as given.
    """
    requested = "SVG" if format is None else format
    fmt = requested.lower()
    supported_formats = ["svg", "ascii", "unicode"]
    if fmt in supported_formats:
        return fmt
    raise ValueError(
        f"Unknown format '{requested}'. Supported formats are {supported_formats}"
    )
354
+
355
+
356
def check_order(order):
    """
    Checks the specified drawing order is valid and returns the corresponding
    tree traversal order. The default drawing order, if None, is "minlex".
    """
    traversal_orders = {"minlex": "minlex_postorder", "tree": "postorder"}
    if order is None:
        return traversal_orders["minlex"]
    # Silently accept a tree traversal order as a valid order, so we can
    # call this check twice if necessary
    if order in traversal_orders.values():
        return order
    if order in traversal_orders:
        return traversal_orders[order]
    raise ValueError(
        f"Unknown display order '{order}'. "
        f"Supported orders are {list(traversal_orders)}"
    )
377
+
378
+
379
def check_x_scale(x_scale):
    """
    Checks the specified x_scale is valid and sets default if None.

    :param x_scale: Either "physical", "treewise" or None (meaning "physical").
    :return: The validated x_scale.
    :raises ValueError: If x_scale is not one of the supported scales.
    """
    if x_scale is None:
        return "physical"
    x_scales = ["physical", "treewise"]
    if x_scale not in x_scales:
        raise ValueError(
            f"Unknown display x_scale '{x_scale}'. Supported scales are {x_scales}"
        )
    return x_scale
391
+
392
+
393
def check_x_lim(x_lim, max_x):
    """
    Checks the specified x_limits are valid and sets default if None.

    :param x_lim: None, or a sequence of two values each of which is either None
        or a number.
    :param max_x: The largest value allowed for ``x_lim[1]``.
    :return: ``x_lim`` as passed in, or ``(None, None)`` if it was None.
    :raises ValueError: If x_lim is not of length 2, or the limits are out of
        range or not in increasing order.
    :raises TypeError: If a limit cannot be compared with a number; the original
        error is attached as the cause.
    """
    if x_lim is None:
        x_lim = (None, None)
    if len(x_lim) != 2:
        raise ValueError("The x_lim parameter must be a list of length 2, or None")
    try:
        if x_lim[0] is not None and x_lim[0] < 0:
            raise ValueError("x_lim[0] cannot be negative")
        if x_lim[1] is not None and x_lim[1] > max_x:
            raise ValueError("x_lim[1] cannot be greater than the sequence length")
        if x_lim[0] is not None and x_lim[1] is not None and x_lim[0] >= x_lim[1]:
            raise ValueError("x_lim[0] must be less than x_lim[1]")
    except TypeError as err:
        # Chain the original error so the failing comparison stays in the traceback
        raise TypeError("x_lim parameters must be numeric") from err
    return x_lim
411
+
412
+
413
def check_y_axis(y_axis):
    """
    Checks the specified y_axis is valid and sets default if None: None means no
    axis (False), and True means an axis on the "left".
    """
    if y_axis is None:
        return False
    if y_axis is True:
        return "left"
    if y_axis in ("left", "right", False):
        return y_axis
    raise ValueError(f"Unknown y_axis specification: '{y_axis}'.")
424
+
425
+
426
def create_tick_labels(tick_values, decimal_places=2):
    """
    Format numeric tick_values as strings with ``decimal_places`` decimals, or with
    no decimals at all if every value is a whole number. If the values cannot be
    rounded (a TypeError is raised), they are returned unchanged.
    """
    try:
        all_whole = np.all(np.round(tick_values) == tick_values)
    except TypeError:
        return tick_values
    precision = decimal_places
    if all_whole:
        precision = 0
    return [format(value, f".{precision}f") for value in tick_values]
437
+
438
+
439
def clip_ts(ts, x_min, x_max, max_num_trees=None):
    """
    Culls the edges of the tree sequence outside the limits of x_min and x_max if
    necessary, and flags internal trees for omission if there are more than
    max_num_trees in the tree sequence.

    Returns a tuple of (tree sequence, tree status array, offsets). The tree
    sequence uses the same genomic scale as the input. The status array has one
    uint8 entry per tree, combining the OMIT, LEFT_CLIP, RIGHT_CLIP and OMIT_MIDDLE
    flags: it says which trees to plot and whether a plotted tree was clipped.
    Clipping can give the leftmost and rightmost trees in the new tree sequence
    reduced spans, so they should be displayed without the relevant breakpoint.

    If x_min is None, we take it to be 0 if the first tree has edges or sites, or
    ``min(edges.left)`` if the first tree represents an empty region.
    Similarly, if x_max is None we take it to be ``ts.sequence_length`` if the last
    tree has edges or mutations, or ``ts.last().interval.left`` if the last tree
    represents an empty region.

    To plot the full ts, including empty flanking regions, specify x_limits of
    [0, seq_len].
    """
    empty_ts_message = (
        "To plot an empty tree sequence, specify x_lim=[0, sequence_length]"
    )
    edges = ts.tables.edges
    sites = ts.tables.sites
    offsets = Offsets()

    if x_min is None:
        if ts.num_edges != 0:
            x_min = np.min(edges.left)
            if ts.num_sites > 0 and np.min(sites.position) < x_min:
                x_min = 0  # First region has no edges, but does have sites => keep
        elif ts.num_sites != 0:
            x_min = 0
        else:
            raise ValueError(empty_ts_message)
    if x_max is None:
        if ts.num_edges != 0:
            x_max = np.max(edges.right)
            if ts.num_sites > 0 and np.max(sites.position) > x_max:
                # Last region has sites but no edges => keep
                x_max = ts.sequence_length
        elif ts.num_sites != 0:
            x_max = ts.sequence_length
        else:
            raise ValueError(empty_ts_message)

    if max_num_trees is None:
        max_num_trees = np.inf
    if max_num_trees < 2:
        raise ValueError("Must show at least 2 trees when clipping a tree sequence")

    needs_clipping = (x_min > 0) or (x_max < ts.sequence_length)
    if not needs_clipping:
        tree_status = np.zeros(ts.num_trees, dtype=np.uint8)
    else:
        # Record the offsets against the original tree sequence before cutting it
        old_breaks = ts.breakpoints(as_array=True)
        offsets.tree = np.searchsorted(old_breaks, x_min, "right") - 2
        offsets.site = np.searchsorted(sites.position, x_min)
        offsets.mutation = np.searchsorted(ts.tables.mutations.site, offsets.site)
        ts = ts.keep_intervals([[x_min, x_max]], simplify=False)
        if ts.num_edges == 0:
            raise ValueError(
                f"Can't limit plotting from {x_min} to {x_max} as whole region is empty"
            )
        edges = ts.tables.edges
        sites = ts.tables.sites
        trees_start = np.min(edges.left)
        trees_end = np.max(edges.right)
        no_sites = ts.num_sites == 0
        tree_status = np.zeros(ts.num_trees, dtype=np.uint8)
        # Are the leftmost/rightmost regions completely empty - if so, don't plot them
        if 0 < x_min <= trees_start and (
            no_sites or trees_start <= np.min(sites.position)
        ):
            tree_status[0] = OMIT
        if trees_end <= x_max < ts.sequence_length and (
            no_sites or trees_end >= np.max(sites.position)
        ):
            tree_status[-1] = OMIT

        # Which breakpoints are new ones, as a result of clipping
        is_new_break = np.logical_not(
            np.isin(ts.breakpoints(as_array=True), old_breaks)
        )
        tree_status[is_new_break[:-1]] |= LEFT_CLIP
        tree_status[is_new_break[1:]] |= RIGHT_CLIP

    first_tree = 0
    if tree_status[0] & OMIT:
        first_tree = 1
    last_tree = ts.num_trees - 1
    if tree_status[-1] & OMIT:
        last_tree = ts.num_trees - 2
    num_shown_trees = last_tree - first_tree + 1
    if num_shown_trees > max_num_trees:
        # Keep trees at both ends (the extra one, if any, at the start) and flag
        # all the trees in between as omitted
        num_end_trees = max_num_trees // 2
        num_start_trees = max_num_trees // 2 + (1 if max_num_trees % 2 else 0)
        assert num_start_trees + num_end_trees == max_num_trees
        omit_from = first_tree + num_start_trees
        omit_to = last_tree - num_end_trees + 1
        tree_status[omit_from:omit_to] = OMIT | OMIT_MIDDLE

    return ts, tree_status, offsets
537
+
538
+
539
def check_y_ticks(ticks: list | Mapping | None) -> Mapping:
    """
    Return a dict mapping each tick position to its label. If ``ticks`` is a
    mapping, the labels are made from its values, otherwise from the tick positions
    themselves; None gives an empty dict.

    Later we might want to implement a tick locator function, such that e.g. ticks=5
    selects ~5 nicely spaced tick locations (with sensible behaviour for log scales)
    """
    if ticks is None:
        return {}
    raw_labels = list(ticks.values()) if isinstance(ticks, Mapping) else ticks
    return dict(zip(ticks, create_tick_labels(raw_labels)))
549
+
550
+
551
def rnd(x):
    """
    Round a number to 6 significant figures, so that the output SVG doesn't have
    unneeded precision. Whole numbers are returned as ints; zero and non-finite
    values are returned unchanged.
    """
    if x == 0 or not math.isfinite(x):
        return x
    # Number of digits before the decimal point (zero or negative for small values)
    magnitude = math.ceil(math.log10(abs(x)))
    rounded = round(x, 6 - magnitude)
    as_int = int(rounded)
    return as_int if as_int == rounded else rounded
563
+
564
+
565
def bold_integer(number):
    """
    Return ``str(number)`` with each ASCII digit replaced by its bold unicode
    equivalent. Any other character, such as the minus sign of a negative integer,
    is left unchanged.
    """
    # For simple integers, it's easier to use bold unicode characters
    # than to try to get the SVG to render a bold font for part of a string
    bold_digits = str.maketrans("0123456789", "𝟎𝟏𝟐𝟑𝟒𝟓𝟔𝟕𝟖𝟗")
    return str(number).translate(bold_digits)
569
+
570
+
571
def edge_and_sample_nodes(ts, omit_regions=None):
    """
    Return the sorted, unique ids of nodes which are mentioned in an edge in this
    tree sequence or which are samples: nodes not connected to an edge are often
    found if x_lim is specified.

    If ``omit_regions`` (an array with 2 columns of region boundaries, in
    increasing order) is given, only edges selected within the remaining regions
    contribute nodes.
    """
    sample_ids = np.where(ts.nodes_flags & NODE_IS_SAMPLE)[0]
    if omit_regions is None or len(omit_regions) == 0:
        ids = np.concatenate((ts.edges_child, ts.edges_parent))
    else:
        edges = ts.tables.edges
        assert omit_regions.shape[1] == 2
        boundaries = omit_regions.flatten()
        assert np.all(boundaries == np.unique(boundaries))  # Check they're in order
        # The regions to use lie between the omitted ones, plus the two flanks
        use_regions = np.concatenate(([0.0], boundaries, [ts.sequence_length]))
        use_regions = use_regions.reshape(-1, 2)
        pieces = [np.array([], dtype=ts.edges_child.dtype)]
        for left, right in use_regions:
            # NOTE(review): the strict "<" drops edges that end exactly at the
            # right-hand boundary of a region - confirm this is intended
            inside = np.logical_and(edges.left >= left, edges.right < right)
            used_edges = edges[inside]
            pieces.extend((used_edges.child, used_edges.parent))
        ids = np.concatenate(pieces)
    return np.unique(np.concatenate((ids, sample_ids)))
592
+
593
+
594
def _postorder_tracked_node_traversal(tree, root, collapse_tracked, key_dict=None):
    # Postorder traversal that only descends into subtrees if they contain
    # a tracked node. Additionally, if collapse_tracked is not None, it is
    # interpreted as a proportion, so that we do not descend into a subtree if
    # that proportion or greater of the samples in the subtree are tracked
    # (nodes with exactly one child are always descended into).
    # If key_dict is provided, use this to sort the children. This allows
    # us to put e.g. the subtrees containing the most tracked nodes first.
    # Private function, for use only in drawing.postorder_tracked_minlex_traversal()

    virtual_root = tree.virtual_root
    # If we deliberately specify the virtual root, it should also be returned
    yield_virtual_root = root == virtual_root
    if root == tskit.NULL:
        root = virtual_root
    # Entries are (node, expanded): expanded nodes have had their children pushed
    pending = [(root, False)]
    while pending:
        node, expanded = pending.pop()
        if expanded:
            if yield_virtual_root or node != virtual_root:
                yield node
            continue
        stop_here = (
            tree.num_children(node) == 0
            or tree.num_tracked_samples(node) == 0
            or (
                collapse_tracked is not None
                and tree.num_children(node) != 1
                and tree.num_tracked_samples(node)
                >= collapse_tracked * tree.num_samples(node)
            )
        )
        if stop_here:
            yield node
            continue
        pending.append((node, True))
        children = [(child, False) for child in tree.children(node)]
        if key_dict is not None:
            children.sort(key=lambda item: key_dict[item[0]], reverse=True)
        pending.extend(children)
637
+
638
+
639
def _postorder_tracked_minlex_traversal(tree, root=None, *, collapse_tracked=None):
    """
    Postorder traversal for drawing purposes that places child nodes with the
    most tracked sample descendants first (then sorts ties by minlex on leaf node
    ids). Additionally, this traversal only descends into subtrees if they contain
    a tracked node, and may not descend into other subtrees if the
    ``collapse_tracked`` parameter is set to a numeric value. More specifically, if
    the proportion of tracked samples in the subtree is greater than or equal to
    ``collapse_tracked``, the subtree is not descended into.
    """
    if root is None:
        root = tskit.NULL
    parent_of = tree.parent_array
    sort_keys = {}
    previous = tree.virtual_root
    # First pass: work out a sort key for every node that will be visited
    for node in _postorder_tracked_node_traversal(tree, root, collapse_tracked):
        # In postorder, a node with visited children directly follows its last child
        follows_child = parent_of[previous] == node
        previous = node
        if follows_child:
            smallest_tip = min(
                sort_keys[child][1]
                for child in tree.children(node)
                if child in sort_keys
            )
        else:
            smallest_tip = node
        # Sort by number of tracked samples (desc), then by minlex
        sort_keys[node] = (-tree.num_tracked_samples(node), smallest_tip)

    # Second pass: the same traversal, with the children sorted by those keys
    return _postorder_tracked_node_traversal(
        tree, root, collapse_tracked, key_dict=sort_keys
    )
668
+
669
+
670
def draw_tree(
    tree,
    width=None,
    height=None,
    node_labels=None,
    node_colours=None,
    mutation_labels=None,
    mutation_colours=None,
    format=None,  # noqa A002
    edge_colours=None,
    time_scale=None,
    tree_height_scale=None,
    max_time=None,
    min_time=None,
    max_tree_height=None,
    order=None,
    omit_sites=None,
):
    """
    Draw a single tree, returning an SVGString for the "svg" format or a plain
    string for the text formats. See tree.draw() for documentation on the arguments.
    """
    # tree_height_scale and max_tree_height are the old names (deprecated in 0.3.6)
    # for time_scale and max_time: the new names win if both are given
    if time_scale is None and tree_height_scale is not None:
        time_scale = tree_height_scale
        warnings.warn(
            "tree_height_scale is deprecated; use time_scale instead",
            FutureWarning,
            stacklevel=4,
        )
    if max_time is None and max_tree_height is not None:
        max_time = max_tree_height
        warnings.warn(
            "max_tree_height is deprecated; use max_time instead",
            FutureWarning,
            stacklevel=4,
        )

    fmt = check_format(format)
    if fmt != "svg":
        svg_only_arguments = (
            ("width", width),
            ("height", height),
            ("mutation_labels", mutation_labels),
            ("mutation_colours", mutation_colours),
            ("node_colours", node_colours),
            ("edge_colours", edge_colours),
            ("time_scale", time_scale),
        )
        for name, value in svg_only_arguments:
            if value is not None:
                raise ValueError(f"Text trees do not support {name}")

        text_tree = VerticalTextTree(
            tree,
            node_labels=node_labels,
            max_time=max_time,
            min_time=min_time,
            use_ascii=(fmt == "ascii"),
            orientation=TOP,
            order=order,
        )
        return str(text_tree)

    if width is None:
        width = 200
    if height is None:
        height = 200

    def as_style_attrs(colour_map, css_property, hidden_style):
        # Turn an {id: colour} map into an {id: {"style": ...}} map of attributes
        if colour_map is None:
            return None
        return {
            key: {
                "style": (
                    hidden_style if colour is None else f"{css_property}:{colour};"
                )
            }
            for key, colour in colour_map.items()
        }

    # Set style rather than fill & stroke directly to override top stylesheet
    # Old semantics were to not draw the node if colour is None.
    # Setting opacity to zero has the same effect.
    svg_tree = SvgTree(
        tree,
        (width, height),
        node_labels=node_labels,
        mutation_labels=mutation_labels,
        time_scale=time_scale,
        max_time=max_time,
        min_time=min_time,
        node_attrs=as_style_attrs(node_colours, "fill", "fill-opacity:0;"),
        edge_attrs=as_style_attrs(edge_colours, "stroke", "stroke-opacity:0;"),
        node_label_attrs=None,
        mutation_attrs=as_style_attrs(mutation_colours, "fill", "fill-opacity:0;"),
        order=order,
        omit_sites=omit_sites,
    )
    return SVGString(svg_tree.drawing.tostring())
776
+
777
+
778
def add_class(attrs_dict, classes_str):
    """
    Appends classes_str (after a space) to the 'class' entry of attrs_dict, or
    creates that entry if it does not exist. The dict is modified in place.
    """
    try:
        existing = attrs_dict["class"]
    except KeyError:
        attrs_dict["class"] = classes_str
    else:
        attrs_dict["class"] = existing + (" " + classes_str)
784
+
785
+
786
@dataclass
class Plotbox:
    """
    A rectangle of size ``total_size`` (width, height) with padding on each side.
    The inner box, inside the padding, must be at least 1 unit wide and high.
    """

    total_size: list
    pad_top: float = 0
    pad_left: float = 0
    pad_bottom: float = 0
    pad_right: float = 0

    def __post_init__(self):
        self._check()

    def _check(self):
        # Raise if the padding leaves no room for the inner box
        if self.width < 1 or self.height < 1:
            raise ValueError("Image size too small to fit")

    def set_padding(self, top, left, bottom, right):
        """Set all four paddings at once, then check the inner box is still valid."""
        self.pad_top, self.pad_left = top, left
        self.pad_bottom, self.pad_right = bottom, right
        self._check()

    @property
    def max_x(self):
        return self.total_size[0]

    @property
    def max_y(self):
        return self.total_size[1]

    @property
    def left(self):  # Alias for consistency with top & bottom
        return self.pad_left

    @property
    def right(self):
        return self.max_x - self.pad_right

    @property
    def top(self):  # Alias for consistency with top & bottom
        return self.pad_top

    @property
    def bottom(self):
        return self.max_y - self.pad_bottom

    @property
    def width(self):
        # Equivalent to right - left
        return self.max_x - self.pad_right - self.pad_left

    @property
    def height(self):
        # Equivalent to bottom - top
        return self.max_y - self.pad_bottom - self.pad_top

    def draw(self, dwg, add_to, colour="grey"):
        """Outline the outer and inner boxes with dashed lines (used for debugging)."""
        outlines = (
            ((0, 0), (self.max_x, self.max_y), "15,15", "outer_plotbox"),
            ((self.left, self.top), (self.width, self.height), "5,5", "inner_plotbox"),
        )
        for insert, size, dashes, css_class in outlines:
            add_to.add(
                dwg.rect(
                    insert,
                    size,
                    fill="white",
                    fill_opacity=0,
                    stroke=colour,
                    stroke_dasharray=dashes,
                    class_=css_class,
                )
            )
864
+
865
+
866
class SvgPlot:
    """
    The base class for plotting any box to canvas
    """

    text_height = 14  # May want to calculate this based on a font size
    line_height = text_height * 1.2  # allowing padding above and below a line
    default_width = 200  # for a single tree
    default_height = 200

    def __init__(
        self,
        size,
        svg_class,
        root_svg_attributes=None,
        canvas_size=None,
        preamble=None,
    ):
        """
        Creates self.drawing, a Drawing object for further use, and populates
        it with a base group of class ``svg_class``. The root_groups will be
        populated with items that can be accessed from the outside, such as the
        plotbox, axes, etc. The canvas defaults to the same size as the plot.
        """
        if root_svg_attributes is None:
            root_svg_attributes = {}
        if canvas_size is None:
            canvas_size = size

        self.drawing = Drawing(size=canvas_size, **root_svg_attributes)
        self.root_svg_attributes = root_svg_attributes
        self.svg_class = svg_class
        self.preamble = preamble
        self.image_size = size
        self.plotbox = Plotbox(size)
        self.root_groups = {}
        self.timescaling = None
        self.dwg_base = self.drawing.add(self.drawing.g(class_=svg_class))

    def draw(self, path=None):
        """
        Return the drawing as an SVGString, with any preamble placed first in the
        root element. If ``path`` is given, the drawing is also saved to that file.
        """
        if self.preamble is not None:
            self.drawing.root.children[0] = self.preamble
        svg_text = self.drawing.tostring()
        if path is not None:
            # TODO remove the 'pretty' when we are done debugging this.
            self.drawing.saveas(path, pretty=True)
        return SVGString(svg_text)

    def get_plotbox(self):
        """
        Get the group of class "plotbox", creating it in the base group if necessary.
        """
        groups = self.root_groups
        if "plotbox" in groups:
            return groups["plotbox"]
        groups["plotbox"] = self.dwg_base.add(self.drawing.g(class_="plotbox"))
        return groups["plotbox"]

    def add_text_in_group(self, text, add_to, pos, group_class=None, **kwargs):
        """
        Add the text to the elem within a group; allows text rotations to work
        smoothly, otherwise, if x & y parameters are used to position text, rotations
        applied to the text tag occur around the (0,0) point of the containing group
        """
        x, y = rnd(pos[0]), rnd(pos[1])
        group_attributes = {"transform": f"translate({x} {y})"}
        if group_class is not None:
            group_attributes["class_"] = group_class
        group = add_to.add(self.drawing.g(**group_attributes))
        group.add(self.drawing.text(text, **kwargs))
936
+
937
+
938
class SvgSkippedPlot(SvgPlot):
    """
    A plot of class "skipped", containing only the two lines of text
    "<num_skipped> trees" and "skipped", centred in the plotbox.
    """

    def __init__(
        self,
        size,
        num_skipped,
    ):
        super().__init__(size, svg_class="skipped")
        container = self.get_plotbox()
        centre_x = self.plotbox.width / 2
        centre_y = self.plotbox.height / 2
        half_line = self.line_height / 2
        # One line of text just above the centre, and one just below
        text_lines = (
            (f"{num_skipped} trees", centre_y - half_line),
            ("skipped", centre_y + half_line),
        )
        for text, y in text_lines:
            self.add_text_in_group(
                text, container, (centre_x, y), text_anchor="middle"
            )
960
+
961
+
962
class SvgAxisPlot(SvgPlot):
    """
    The class used for plotting either a tree or a tree sequence as an SVG file.

    Extends :class:`SvgPlot` with a default stylesheet, optional x and y axes
    (with titles, ticks, sites and mutations), and background shading.
    Subclasses must provide ``x_transform`` to map genome positions to plot
    coordinates.
    """

    # Default CSS, prepended to any user-provided style in a single stylesheet.
    # The adjacent string literals are concatenated into one string.
    standard_style = (
        ".background path {fill: #808080; fill-opacity: 0}"
        ".background path:nth-child(odd) {fill-opacity: .1}"
        # Semi-transparent, so that overlapping regions remain visible
        ".x-regions rect {fill: yellow; stroke: black; opacity: 0.5}"
        ".axes {font-size: 14px}"
        ".x-axis .tick .lab {font-weight: bold; dominant-baseline: hanging}"
        ".axes, .tree {font-size: 14px; text-anchor: middle}"
        ".axes line, .edge {stroke: black; fill: none}"
        ".axes .ax-skip {stroke-dasharray: 4}"
        ".y-axis .grid {stroke: #FAFAFA}"
        ".node > .sym {fill: black; stroke: none}"
        ".site > .sym {stroke: black}"
        ".mut text {fill: red; font-style: italic}"
        ".mut.extra text {fill: hotpink}"
        ".mut line {fill: none; stroke: none}"  # Default hide mut line to expose edges
        ".mut .sym {fill: none; stroke: red}"
        ".mut.extra .sym {stroke: hotpink}"
        ".node .mut .sym {stroke-width: 1.5px}"
        ".tree text, .tree-sequence text {dominant-baseline: central}"
        ".plotbox .lab.lft {text-anchor: end}"
        ".plotbox .lab.rgt {text-anchor: start}"
        ".polytomy line {stroke: black; stroke-dasharray: 1px, 1px}"
        ".polytomy text {paint-order:stroke;stroke-width:0.3em;stroke:white}"
    )

    # TODO: we may want to make some of the constants below into parameters
    root_branch_fraction = 1 / 8  # Rel root branch len, unless it has a timed mutation
    # Default length of axis tick marks
    default_tick_length = 5
    # Default length of the upper tick used when drawing sites on the x axis
    default_tick_length_site = 10
    # Placement of the axes lines within the padding - not used unless axis is plotted
    default_x_axis_offset = 20
    default_y_axis_offset = 40
999
+
1000
def __init__(
    self,
    ts,
    size,
    root_svg_attributes,
    style,
    svg_class,
    time_scale,
    x_axis=None,
    y_axis=None,
    x_label=None,
    y_label=None,
    offsets=None,
    debug_box=None,
    omit_sites=None,
    canvas_size=None,
    mutation_titles=None,
    preamble=None,
):
    """
    Set up the base plot, register the stylesheet, and record the axis
    settings.

    The user ``style`` (if any) is appended to ``standard_style``. When an
    axis is requested without an explicit label, a default one is chosen:
    "Genome position" for x; "Node time" (rank time scale) or "Time ago" for
    y, followed by the time units in brackets when these are known.
    """
    super().__init__(
        size, svg_class, root_svg_attributes, canvas_size, preamble=preamble
    )
    self.ts = ts
    # Put all styles in a single stylesheet (required for Inkscape 0.92)
    extra_style = "" if style is None else style
    stylesheet = self.standard_style + extra_style
    drawing = self.drawing
    drawing.defs.add(drawing.style(stylesheet))
    self.debug_box = debug_box
    self.time_scale = check_time_scale(time_scale)
    self.y_axis = check_y_axis(y_axis)
    self.x_axis = x_axis

    if x_label is None and x_axis:
        x_label = "Genome position"
    if y_label is None and y_axis:
        y_label = "Node time" if time_scale == "rank" else "Time ago"
        if ts.time_units != tskit.TIME_UNITS_UNKNOWN:
            y_label += f" ({ts.time_units})"
    self.x_label = x_label
    self.y_label = y_label

    if offsets is None:
        offsets = Offsets()
    self.offsets = offsets
    self.omit_sites = omit_sites
    if mutation_titles is None:
        mutation_titles = {}
    self.mutation_titles = mutation_titles
    # mutations in here get an additional class
    self.mutations_outside_tree = set()
1050
+
1051
def set_spacing(self, top=0, left=0, bottom=0, right=0):
    """
    Set the plotbox padding, enlarging it where needed to fit the axes.

    The axis offsets start from the class defaults and grow by one line
    height when the corresponding axis label is present. An x axis adds its
    offset to ``bottom``; a y axis replaces the ``left`` or ``right`` value
    with its offset. If debugging boxes are enabled they are drawn into a
    new "debug" root group.
    """
    x_offset = self.default_x_axis_offset
    y_offset = self.default_y_axis_offset
    if self.x_label:
        x_offset += self.line_height
    if self.y_label:
        y_offset += self.line_height
    self.x_axis_offset = x_offset
    self.y_axis_offset = y_offset

    if self.x_axis:
        bottom += x_offset
    y_side = self.y_axis
    if y_side == "left":
        # Override user-provided values, so y-axis is at x=0
        left = y_offset
    elif y_side == "right":
        right = y_offset
    self.plotbox.set_padding(top, left, bottom, right)

    if self.debug_box:
        debug_group = self.dwg_base.add(self.drawing.g(class_="debug"))
        self.root_groups["debug"] = debug_group
        self.plotbox.draw(self.drawing, debug_group)
1075
+
1076
def get_axes(self):
    """
    Return the group of class "axes", adding it to the base group the first
    time it is requested and reusing it on later calls.
    """
    groups = self.root_groups
    if "axes" in groups:
        return groups["axes"]
    groups["axes"] = self.dwg_base.add(self.drawing.g(class_="axes"))
    return groups["axes"]
1080
+
1081
def draw_x_axis(
    self,
    tick_positions=None,  # np.array of ax ticks below (+ above if sites is None)
    tick_labels=None,  # Tick labels below axis. If None, use the position value
    tick_length_lower=default_tick_length,
    tick_length_upper=None,  # If None, use the same as tick_length_lower
    site_muts=None,  # A dict of site id => mutation to plot as ticks on the x axis
    alternate_dash_positions=None,  # Where to alternate the axis from solid to dash
    x_regions=None,  # A dict of (left, right):label items to place in boxes
):
    """
    Draw the x axis into a new "x-axis" group inside the axes group.

    Does nothing if ``self.x_axis`` is falsy. Otherwise adds, in order: the
    axis title (if ``self.x_label`` is set), one labelled box per entry in
    ``x_regions``, the axis line (split at ``alternate_dash_positions`` into
    segments whose class alternates between "ax-line" and "ax-skip"), the
    ticks with their labels, and finally the sites with their mutations drawn
    as stacked chevrons (unless sites are omitted).

    :raises ValueError: if a region does not satisfy
        ``0 <= left < right <= sequence_length``.
    """
    if not self.x_axis:
        return
    if alternate_dash_positions is None:
        alternate_dash_positions = np.array([])
    if x_regions is None:
        x_regions = {}
    dwg = self.drawing
    axes = self.get_axes()
    x_axis = axes.add(dwg.g(class_="x-axis"))
    if self.x_label:
        self.add_text_in_group(
            self.x_label,
            x_axis,
            pos=((self.plotbox.left + self.plotbox.right) / 2, self.plotbox.max_y),
            group_class="title",
            class_="lab",
            transform="translate(0 -11)",
            text_anchor="middle",
        )
    if len(x_regions) > 0:
        regions_group = x_axis.add(dwg.g(class_="x-regions"))
        for i, ((left, right), label) in enumerate(x_regions.items()):
            if not (0 <= left < right <= self.ts.sequence_length):
                raise ValueError(
                    f"Invalid coordinates ({left} to {right}) for x-axis region"
                )
            x1 = self.x_transform(left)
            x2 = self.x_transform(right)
            y = self.plotbox.max_y - self.x_axis_offset
            region = regions_group.add(dwg.g(class_=f"r{i}"))
            # The rect carries the same per-region class as its group. This
            # must be an f-string: a plain literal would give every rect the
            # class "r{i}" verbatim.
            region.add(
                dwg.rect((x1, y), (x2 - x1, self.line_height), class_=f"r{i}")
            )
            self.add_text_in_group(
                label,
                region,
                pos=((x2 + x1) / 2, y + self.line_height / 2),
                class_="lab",
                text_anchor="middle",
            )
    if tick_length_upper is None:
        tick_length_upper = tick_length_lower
    y = rnd(self.plotbox.max_y - self.x_axis_offset)
    # Axis line end points: plotbox edges plus any solid <-> dashed switches
    dash_locs = np.concatenate(
        (
            [self.plotbox.left],
            self.x_transform(alternate_dash_positions),
            [self.plotbox.right],
        )
    )
    for i, (x1, x2) in enumerate(zip(dash_locs[:-1], dash_locs[1:])):
        x_axis.add(
            dwg.line(
                (rnd(x1), y),
                (rnd(x2), y),
                class_="ax-skip" if i % 2 else "ax-line",
            )
        )
    if tick_positions is not None:
        if tick_labels is None or isinstance(tick_labels, np.ndarray):
            if tick_labels is None:
                tick_labels = tick_positions
            tick_labels = create_tick_labels(tick_labels)  # format integers

        # Ticks only extend above the axis when no sites are drawn there
        upper_length = -tick_length_upper if site_muts is None else 0
        ticks_group = x_axis.add(dwg.g(class_="ticks"))
        for pos, lab in itertools.zip_longest(tick_positions, tick_labels):
            tick = ticks_group.add(
                dwg.g(
                    class_="tick",
                    transform=f"translate({rnd(self.x_transform(pos))} {y})",
                )
            )
            tick.add(dwg.line((0, rnd(upper_length)), (0, rnd(tick_length_lower))))
            self.add_text_in_group(
                lab,
                tick,
                class_="lab",
                # place origin at the bottom of the tick plus a single px space
                pos=(0, tick_length_lower + 1),
            )
    if not self.omit_sites and site_muts is not None:
        # Add sites as vertical lines with overlaid mutations as upper chevrons
        w = tick_length_upper / 4  # chevron half-width, same for every mutation
        for s_id, mutations in site_muts.items():
            s = self.ts.site(s_id)
            x = self.x_transform(s.position)
            site = x_axis.add(
                dwg.g(
                    class_=f"site s{s.id + self.offsets.site}",
                    transform=f"translate({rnd(x)} {y})",
                )
            )
            site.add(dwg.line((0, 0), (0, rnd(-tick_length_upper)), class_="sym"))
            for i, m in enumerate(reversed(mutations)):
                mutation_class = f"mut m{m.id + self.offsets.mutation}"
                if m.id in self.mutations_outside_tree:
                    mutation_class += " extra"
                mut = dwg.g(class_=mutation_class)
                # Each successive chevron is stacked 4 units above the last
                h = -i * 4 - 1.5
                # Chevron symbol
                symbol = mut.add(
                    dwg.polyline(
                        [
                            (rnd(w), rnd(h - 2 * w)),
                            (0, rnd(h)),
                            (rnd(-w), rnd(h - 2 * w)),
                        ],
                        class_="sym",
                    )
                )
                if m.id in self.mutation_titles:
                    symbol.set_desc(title=self.mutation_titles[m.id])
                site.add(mut)
1205
+
1206
def draw_y_axis(
    self,
    ticks,  # A dict of pos->label
    upper=None,  # In plot coords
    lower=None,  # In plot coords
    tick_length_outer=default_tick_length,  # Positive means towards the outside
    gridlines=None,
    side="left",  # 'left' or 'right', where the axis is drawn
):
    """
    Draw the y axis into a new "y-axis" group inside the axes group.

    Nothing is drawn when neither a y axis nor a y label is set. The title is
    added if ``self.y_label`` is set; the axis line, ticks, tick labels and
    optional gridlines are added only if ``self.y_axis`` is set. Tick
    positions are converted with ``self.timescaling.transform``; any that
    fall outside ``upper``..``lower`` are still drawn, but a warning listing
    them is logged.
    """
    if not self.y_axis and not self.y_label:
        return
    if upper is None:
        upper = self.plotbox.top
    if lower is None:
        lower = self.plotbox.bottom
    drawing = self.drawing

    on_left = side == "left"
    if on_left:
        axis_x = rnd(self.y_axis_offset)
        grid_width = self.plotbox.right - axis_x
        outward = -1
        label_anchor = "end"
        title_pos = (0, (upper + lower) / 2)
        title_transform = "translate(11) rotate(-90)"
    else:
        axis_x = rnd(self.plotbox.max_x - self.y_axis_offset)
        grid_width = axis_x - self.plotbox.left
        outward = 1
        label_anchor = "start"
        title_pos = (self.plotbox.max_x, (upper + lower) / 2)
        title_transform = "translate(-11) rotate(90)"

    axis_group = self.get_axes().add(drawing.g(class_="y-axis"))
    if self.y_label:
        self.add_text_in_group(
            self.y_label,
            axis_group,
            pos=title_pos,
            group_class="title",
            class_="lab",
            text_anchor="middle",
            transform=title_transform,
        )
    if not self.y_axis:
        return

    axis_group.add(
        drawing.line((axis_x, rnd(lower)), (axis_x, rnd(upper)), class_="ax-line")
    )
    ticks_group = axis_group.add(drawing.g(class_="ticks"))
    out_of_range = {}
    for tick_value, tick_label in ticks.items():
        tick_y = self.timescaling.transform(tick_value)
        if tick_y > lower or tick_y < upper:  # nb lower > upper in SVG coords
            out_of_range[tick_value] = tick_label
        tick_group = ticks_group.add(
            drawing.g(class_="tick", transform=f"translate({axis_x} {rnd(tick_y)})")
        )
        if gridlines:
            tick_group.add(drawing.line((0, 0), (rnd(grid_width), 0), class_="grid"))
        tick_group.add(drawing.line((0, 0), (rnd(outward * tick_length_outer), 0)))
        # place the label origin just beyond the tickmark plus a single px space
        self.add_text_in_group(
            tick_label,
            tick_group,
            pos=(rnd(outward * (tick_length_outer + 1)), 0),
            class_="lab",
            text_anchor=label_anchor,
        )
    if out_of_range:
        logging.warning(f"Ticks {out_of_range} lie outside the plotted axis")
1274
+
1275
def shade_background(
    self,
    breaks,
    tick_length_lower,
    tree_width=None,
    bottom_padding=None,
):
    """
    Add a "background" root group holding one closed path per interval
    between consecutive ``breaks``.

    Each path runs along the top of a tree box of width ``tree_width``, down
    its side, diagonally to the transformed break position on the x axis,
    along the axis between the ticks, and back up again. Nothing is drawn if
    there is no x axis. ``tree_width`` and ``bottom_padding`` default to the
    plotbox width and bottom padding.
    """
    if not self.x_axis:
        return
    if tree_width is None:
        tree_width = self.plotbox.width
    if bottom_padding is None:
        bottom_padding = self.plotbox.pad_bottom
    plot_breaks = self.x_transform(np.array(breaks))
    drawing = self.drawing

    # For tree sequences, we need to add on the background shaded regions
    background = self.dwg_base.add(drawing.g(class_="background"))
    self.root_groups["background"] = background
    y = self.image_size[1] - self.x_axis_offset - self.plotbox.top
    # Shift diagonal lines between tree & axis into the treebox a little
    diag_height = y - (self.image_size[1] - bottom_padding) + self.plotbox.top
    box_left = self.plotbox.left

    # NB: the path below draws straight diagonal lines between the tree boxes
    # and the X axis. An alternative implementation using bezier curves could
    # substitute the following for lines 2 and 4 of the path spec string
    # "l0,{box_h:g} c0,{diag_h} {rdiag_x},0 {rdiag_x},{diag_h} "
    # "c0,-{diag_h} {ldiag_x},0 {ldiag_x},-{diag_h} l0,-{box_h:g}z"
    path_spec = (
        "M{start_x:g},{top:g} l{box_w:g},0 "  # Top left to top right of tree
        "l0,{box_h:g} l{rdiag_x:g},{diag_h:g} "  # Down to axis
        "l0,{tick_h:g} l{ax_x:g},0 l0,-{tick_h:g} "  # Between axis ticks
        "l{ldiag_x:g},-{diag_h:g} l0,-{box_h:g}z"  # Up from axis
    )
    for i in range(1, len(breaks)):
        break_x = plot_breaks[i]
        prev_break_x = plot_breaks[i - 1]
        tree_x = i * tree_width + box_left
        prev_tree_x = (i - 1) * tree_width + box_left
        spec = path_spec.format(
            top=rnd(self.plotbox.top),
            start_x=rnd(prev_tree_x),
            box_w=rnd(tree_x - prev_tree_x),
            box_h=rnd(y - diag_height),
            rdiag_x=rnd(break_x - tree_x),
            diag_h=rnd(diag_height),
            tick_h=rnd(tick_length_lower),
            ax_x=rnd(prev_break_x - break_x),
            ldiag_x=rnd(rnd(prev_tree_x) - rnd(prev_break_x)),
        )
        background.add(drawing.path(spec))
1324
+
1325
def x_transform(self, x):
    """
    Placeholder for the genome position -> plot coordinate mapping; this
    base implementation always raises ``NotImplementedError``.
    """
    message = "No transform func defined for genome pos -> plot coords"
    raise NotImplementedError(message)
1329
+
1330
+
1331
+ class SvgTreeSequence(SvgAxisPlot):
1332
+ """
1333
+ A class to draw a tree sequence in SVG format.
1334
+
1335
+ See :meth:`TreeSequence.draw_svg` for a description of usage and parameters.
1336
+ """
1337
+
1338
def __init__(
    self,
    ts,
    size,
    x_scale,
    time_scale,
    node_labels,
    mutation_labels,
    root_svg_attributes,
    style,
    order,
    force_root_branch,
    symbol_size,
    x_axis,
    y_axis,
    x_label,
    y_label,
    y_ticks,
    x_regions=None,
    y_gridlines=None,
    x_lim=None,
    max_time=None,
    min_time=None,
    node_attrs=None,
    mutation_attrs=None,
    edge_attrs=None,
    node_label_attrs=None,
    mutation_label_attrs=None,
    node_titles=None,
    mutation_titles=None,
    tree_height_scale=None,
    max_tree_height=None,
    max_num_trees=None,
    title=None,
    preamble=None,
    **kwargs,
):
    """
    Build the complete tree sequence drawing.

    The tree sequence is clipped to ``x_lim`` / ``max_num_trees``, then one
    subplot is created per plotted tree (an ``SvgTree``) and per run of
    skipped trees (an ``SvgSkippedPlot``). The optional title, the x axis
    and the y axis are drawn, and finally the root groups of each subplot
    are placed side by side inside the top-level plotbox.

    ``max_tree_height`` and ``tree_height_scale`` are deprecated aliases for
    ``max_time`` and ``time_scale``; using them emits a ``FutureWarning``.

    :raises ValueError: if a y axis is requested but the subplots do not all
        share one timescaling.
    """
    if max_time is None and max_tree_height is not None:
        max_time = max_tree_height
        # Deprecated in 0.3.6
        warnings.warn(
            "max_tree_height is deprecated; use max_time instead",
            FutureWarning,
            stacklevel=4,
        )
    if time_scale is None and tree_height_scale is not None:
        time_scale = tree_height_scale
        # Deprecated in 0.3.6
        warnings.warn(
            "tree_height_scale is deprecated; use time_scale instead",
            FutureWarning,
            stacklevel=4,
        )
    x_lim = check_x_lim(x_lim, max_x=ts.sequence_length)
    # From here on `ts` is the clipped tree sequence returned by clip_ts
    ts, self.tree_status, offsets = clip_ts(ts, x_lim[0], x_lim[1], max_num_trees)

    # Per-tree flags: which trees are plotted, and where a "skipped"
    # placeholder goes (wherever the OMIT_MIDDLE flag changes between one
    # tree and the next)
    use_tree = self.tree_status & OMIT == 0
    use_skipped = np.append(np.diff(self.tree_status & OMIT_MIDDLE == 0) == 1, 0)
    num_plotboxes = np.sum(np.logical_or(use_tree, use_skipped))
    if size is None:
        # Default: one default-width box per plotted tree or skipped region
        size = (self.default_width * int(num_plotboxes), self.default_height)
    if max_time is None:
        max_time = "ts"
    if min_time is None:
        min_time = "ts"
    # X axis shown by default
    if x_axis is None:
        x_axis = True
    super().__init__(
        ts,
        size,
        root_svg_attributes,
        style,
        svg_class="tree-sequence",
        time_scale=time_scale,
        x_axis=x_axis,
        y_axis=y_axis,
        x_label=x_label,
        y_label=y_label,
        offsets=offsets,
        mutation_titles=mutation_titles,
        preamble=preamble,
        **kwargs,
    )
    x_scale = check_x_scale(x_scale)
    order = check_order(order)
    if node_labels is None:
        node_labels = {u: str(u) for u in range(ts.num_nodes)}
    if force_root_branch is None:
        # Show root branches if any plotted tree has a mutation above a root
        force_root_branch = any(
            any(tree.parent(mut.node) == NULL for mut in tree.mutations())
            for tree, use in zip(ts.trees(), use_tree)
            if use
        )

    # TODO add general padding arguments following matplotlib's terminology.
    self.set_spacing(
        top=0 if title is None else self.line_height, left=20, bottom=10, right=20
    )
    # The padded plotbox is divided equally between the subplots, so this
    # must come after set_spacing
    subplot_size = (self.plotbox.width / num_plotboxes, self.plotbox.height)
    subplots = []
    for tree, use, summary in zip(ts.trees(), use_tree, use_skipped):
        if use:
            subplots.append(
                SvgTree(
                    tree,
                    size=subplot_size,
                    time_scale=time_scale,
                    node_labels=node_labels,
                    mutation_labels=mutation_labels,
                    node_titles=node_titles,
                    mutation_titles=mutation_titles,
                    order=order,
                    force_root_branch=force_root_branch,
                    symbol_size=symbol_size,
                    max_time=max_time,
                    min_time=min_time,
                    node_attrs=node_attrs,
                    mutation_attrs=mutation_attrs,
                    edge_attrs=edge_attrs,
                    node_label_attrs=node_label_attrs,
                    mutation_label_attrs=mutation_label_attrs,
                    offsets=offsets,
                    # Do not plot axes on these subplots
                    **kwargs,  # pass though e.g. debug boxes
                )
            )
            # Remembered so a later placeholder can report how many trees
            # were skipped since the last plotted one
            last_used_index = tree.index
        elif summary:
            subplots.append(
                SvgSkippedPlot(
                    size=subplot_size, num_skipped=tree.index - last_used_index
                )
            )
    y = self.plotbox.top
    if title is not None:
        self.add_text_in_group(
            title,
            self.drawing,
            pos=(self.plotbox.max_x / 2, 0),
            dominant_baseline="hanging",
            group_class="title",
            text_anchor="middle",
        )
    # The first subplot's plotbox is used as the reference geometry for the
    # axes
    self.tree_plotbox = subplots[0].plotbox
    tree_is_used, breaks, skipbreaks = self.find_used_trees()
    self.draw_x_axis(
        x_scale,
        tree_is_used,
        breaks,
        skipbreaks,
        tick_length_lower=self.default_tick_length,  # TODO - parameterize
        tick_length_upper=self.default_tick_length_site,  # TODO - parameterize
        x_regions=x_regions,
    )
    y_low = self.tree_plotbox.bottom
    if y_axis is not None:
        # A shared y axis needs every subplot to use the same timescaling
        tscales = {s.timescaling for s in subplots if s.timescaling}
        if len(tscales) > 1:
            raise ValueError(
                "Can't draw a tree sequence Y axis if trees vary in timescale"
            )
        self.timescaling = tscales.pop()
        y_low = self.timescaling.transform(self.timescaling.min_time)
    if y_ticks is None:
        # Default ticks: the distinct times of the nodes returned by
        # edge_and_sample_nodes
        used_nodes = edge_and_sample_nodes(ts, breaks[skipbreaks])
        y_ticks = np.unique(ts.nodes_time[used_nodes])
        if self.time_scale == "rank":
            # Ticks labelled by time not rank
            y_ticks = dict(enumerate(y_ticks))

    self.draw_y_axis(
        ticks=check_y_ticks(y_ticks),
        upper=self.tree_plotbox.top,
        lower=y_low,
        tick_length_outer=self.default_tick_length,
        gridlines=y_gridlines,
        side="right" if y_axis == "right" else "left",
    )

    subplot_x = self.plotbox.left
    container = self.get_plotbox()  # Top-level TS plotbox contains all trees
    container["class"] = container["class"] + " trees"
    for subplot in subplots:
        # Each subplot gets its own translated group; the x offset advances
        # by the subplot's image width
        svg_subplot = container.add(
            self.drawing.g(
                class_=subplot.svg_class,
                transform=f"translate({rnd(subplot_x)} {y})",
            )
        )
        for svg_items in subplot.root_groups.values():
            svg_subplot.add(svg_items)
        subplot_x += subplot.image_size[0]
1531
+
1532
def find_used_trees(self):
    """
    Work out which trees and breakpoints take part in the plot.

    Returns a 3-tuple of: a boolean array marking the trees that are actually
    plotted; the breakpoints that form the left or right edge of a plotted
    tree; and an n x 2 array (often n=0) of indexes into those breakpoints,
    one row per region that should be plotted as "skipped".
    """
    status = self.tree_status
    tree_is_used = (status & OMIT) != OMIT
    # A breakpoint is used if the tree to its right or to its left is plotted
    left_edge_of_used = np.append(tree_is_used, False)
    right_edge_of_used = np.insert(tree_is_used, 0, False)
    break_used = np.logical_or(left_edge_of_used, right_edge_of_used)
    used_breaks = self.ts.breakpoints(True)[break_used]
    # Mark the breakpoints at which the OMIT_MIDDLE flag switches on or off
    skip_flag_changes = np.diff(status & OMIT_MIDDLE) != 0
    mark_skip_transitions = np.concatenate(([False], skip_flag_changes, [False]))
    skipregion_indexes = np.where(mark_skip_transitions[break_used])[0]
    assert len(skipregion_indexes) % 2 == 0  # all skipped regions have start, end
    skipregion_pairs = skipregion_indexes.reshape((-1, 2))
    return tree_is_used, used_breaks, skipregion_pairs
1551
+
1552
def draw_x_axis(
    self,
    x_scale,
    tree_is_used,
    breaks,
    skipbreaks,
    x_regions,
    tick_length_lower=SvgAxisPlot.default_tick_length,
    tick_length_upper=SvgAxisPlot.default_tick_length_site,
):
    """
    Set up ``self.x_transform`` for the chosen ``x_scale`` and then defer to
    the ``SvgAxisPlot`` implementation to draw the axis.

    For the "physical" scale this also collects the sites of the plotted
    trees and shades the background; skipped regions make the transform a
    piecewise-linear function. For any other scale, ticks are evenly spaced,
    no sites are drawn, and passing ``x_regions`` raises ``ValueError``. The
    first / last tick is dropped when a plotted tree has been clipped on the
    left / right.
    """
    if not (self.x_axis or self.x_label):
        return
    if x_scale == "physical":
        # In a tree sequence plot, the x_transform is used for the ticks, background
        # shading positions, and sites along the x-axis. Each tree will have its own
        # separate x_transform function for node positions within the tree.

        # For a plot with a break on the x-axis (representing "skipped" trees), the
        # x_transform is a piecewise function. We need to identify the breakpoints
        # where the x-scale transitions from the standard scale to the scale(s) used
        # within a skipped region

        skipregion_plot_width = self.tree_plotbox.width
        skip_edges = breaks[skipbreaks]
        skipregion_span = np.diff(skip_edges).T[0]
        # Plot width and genome length left over for the non-skipped regions
        std_plot_width = self.plotbox.width - skipregion_plot_width * len(
            skipregion_span
        )
        std_genome_span = breaks[-1] - breaks[0] - np.sum(skipregion_span)
        std_scale = std_plot_width / std_genome_span
        skipregion_pos = skip_edges.flatten()
        genome_pos = np.concatenate(([breaks[0]], skipregion_pos, [breaks[-1]]))
        # Segments alternate standard, skipped, standard, ...: skipped ones
        # have a fixed plot width, standard ones are scaled by std_scale
        plot_step = np.full(len(genome_pos) - 1, skipregion_plot_width)
        plot_step[::2] = std_scale * np.diff(genome_pos)[::2]
        plot_pos = np.cumsum(np.insert(plot_step, 0, self.plotbox.left))
        # Convert to slope + intercept form
        slope = np.diff(plot_pos) / np.diff(genome_pos)
        intercept = plot_pos[1:] - slope * genome_pos[1:]

        def piecewise_transform(y):
            segment = np.searchsorted(skipregion_pos, y)
            return y * slope[segment] + intercept[segment]

        self.x_transform = piecewise_transform
        tick_positions = breaks
        site_muts = {}
        for tree, use in zip(self.ts.trees(), tree_is_used):
            if not use:
                continue
            for site in tree.sites():
                site_muts[site.id] = site.mutations

        self.shade_background(
            breaks,
            tick_length_lower,
            self.tree_plotbox.max_x,
            self.plotbox.pad_bottom + self.tree_plotbox.pad_bottom,
        )
    else:
        # For a treewise plot, the only time the x_transform is used is to apply
        # to tick positions, so simply use positions 0..num_used_breaks for the
        # positions, and a simple transform
        def evenly_spaced_transform(x):
            return self.plotbox.left + x / (len(breaks) - 1) * self.plotbox.width

        self.x_transform = evenly_spaced_transform
        tick_positions = np.arange(len(breaks))

        site_muts = None  # It doesn't make sense to plot sites for "treewise" plots
        tick_length_upper = None  # No sites plotted, so use the default upper tick
        if x_regions is not None and len(x_regions) > 0:
            raise ValueError("x_regions are not supported for treewise plots")

        # NB: no background shading needed if x_scale is "treewise"

        skipregion_pos = skipbreaks.flatten()

    used_status = self.tree_status[tree_is_used]
    first_tick = 1 if np.any(used_status & LEFT_CLIP) else 0
    last_tick = -1 if np.any(used_status & RIGHT_CLIP) else None
    shown_ticks = slice(first_tick, last_tick)

    super().draw_x_axis(
        tick_positions=tick_positions[shown_ticks],
        tick_labels=breaks[shown_ticks],
        tick_length_lower=tick_length_lower,
        tick_length_upper=tick_length_upper,
        site_muts=site_muts,
        alternate_dash_positions=skipregion_pos,
        x_regions=x_regions,
    )
1640
+
1641
+
1642
+ class SvgTree(SvgAxisPlot):
1643
+ """
1644
+ A class to draw a tree in SVG format.
1645
+
1646
+ See :meth:`Tree.draw_svg` for a description of usage and frequently used parameters.
1647
+ """
1648
+
1649
+ PolytomyLine = collections.namedtuple(
1650
+ "PolytomyLine", "num_branches, num_samples, line_pos"
1651
+ )
1652
+ margin_left = 20
1653
+ margin_right = 20
1654
+ margin_top = 10 # oldest point is line_height below or 2*line_height if title given
1655
+ margin_bottom = 15 # youngest plot points are line_height above this bottom margin
1656
+
1657
+ def __init__(
1658
+ self,
1659
+ tree,
1660
+ size=None,
1661
+ max_time=None,
1662
+ min_time=None,
1663
+ max_tree_height=None,
1664
+ node_labels=None,
1665
+ mutation_labels=None,
1666
+ node_titles=None,
1667
+ mutation_titles=None,
1668
+ root_svg_attributes=None,
1669
+ style=None,
1670
+ order=None,
1671
+ force_root_branch=None,
1672
+ symbol_size=None,
1673
+ x_axis=None,
1674
+ y_axis=None,
1675
+ x_label=None,
1676
+ y_label=None,
1677
+ title=None,
1678
+ x_regions=None,
1679
+ y_ticks=None,
1680
+ y_gridlines=None,
1681
+ all_edge_mutations=None,
1682
+ time_scale=None,
1683
+ tree_height_scale=None,
1684
+ node_attrs=None,
1685
+ mutation_attrs=None,
1686
+ edge_attrs=None,
1687
+ node_label_attrs=None,
1688
+ mutation_label_attrs=None,
1689
+ offsets=None,
1690
+ omit_sites=None,
1691
+ pack_untracked_polytomies=None,
1692
+ preamble=None,
1693
+ **kwargs,
1694
+ ):
1695
+ if max_time is None and max_tree_height is not None:
1696
+ max_time = max_tree_height
1697
+ # Deprecated in 0.3.6
1698
+ warnings.warn(
1699
+ "max_tree_height is deprecated; use max_time instead",
1700
+ FutureWarning,
1701
+ stacklevel=4,
1702
+ )
1703
+ if time_scale is None and tree_height_scale is not None:
1704
+ time_scale = tree_height_scale
1705
+ # Deprecated in 0.3.6
1706
+ warnings.warn(
1707
+ "tree_height_scale is deprecated; use time_scale instead",
1708
+ FutureWarning,
1709
+ stacklevel=4,
1710
+ )
1711
+ if size is None:
1712
+ size = (self.default_width, self.default_height)
1713
+ if symbol_size is None:
1714
+ symbol_size = 6
1715
+ self.symbol_size = symbol_size
1716
+ self.pack_untracked_polytomies = pack_untracked_polytomies
1717
+ ts = tree.tree_sequence
1718
+ tree_index = tree.index
1719
+ if offsets is not None:
1720
+ tree_index += offsets.tree
1721
+ super().__init__(
1722
+ ts,
1723
+ size,
1724
+ root_svg_attributes,
1725
+ style,
1726
+ svg_class=f"tree t{tree_index}",
1727
+ time_scale=time_scale,
1728
+ x_axis=x_axis,
1729
+ y_axis=y_axis,
1730
+ x_label=x_label,
1731
+ y_label=y_label,
1732
+ offsets=offsets,
1733
+ omit_sites=omit_sites,
1734
+ preamble=preamble,
1735
+ **kwargs,
1736
+ )
1737
+ self.tree = tree
1738
+ if order is None or isinstance(order, str):
1739
+ # Can't use the Tree.postorder array as we need minlex
1740
+ self.postorder_nodes = list(tree.nodes(order=check_order(order)))
1741
+ else:
1742
+ # Currently undocumented feature: we can pass a (postorder) list
1743
+ # of nodes to plot, which allows us to draw a subset of nodes, or
1744
+ # stop traversing certain subtrees
1745
+ self.postorder_nodes = order
1746
+
1747
+ # Create some instance variables for later use in plotting
1748
+ self.node_mutations = collections.defaultdict(list)
1749
+ self.edge_attrs = {}
1750
+ self.node_attrs = {}
1751
+ self.node_label_attrs = {}
1752
+ self.mutation_attrs = {}
1753
+ self.mutation_label_attrs = {}
1754
+ self.node_titles = {} if node_titles is None else node_titles
1755
+ self.mutation_titles = {} if mutation_titles is None else mutation_titles
1756
+ self.mutations_over_roots = False
1757
+ # mutations collected per node
1758
+ nodes = set(tree.nodes())
1759
+ unplotted = []
1760
+ if not omit_sites:
1761
+ for site in tree.sites():
1762
+ for mutation in site.mutations:
1763
+ if mutation.node in nodes:
1764
+ self.node_mutations[mutation.node].append(mutation)
1765
+ if tree.parent(mutation.node) == NULL:
1766
+ self.mutations_over_roots = True
1767
+ else:
1768
+ unplotted.append(mutation.id + self.offsets.mutation)
1769
+ if len(unplotted) > 0:
1770
+ warnings.warn(
1771
+ f"Mutations {unplotted} are above nodes which are not present in the "
1772
+ "displayed tree, so are not plotted on the topology.",
1773
+ UserWarning,
1774
+ stacklevel=2,
1775
+ )
1776
+ self.left_extent = tree.interval.left
1777
+ self.right_extent = tree.interval.right
1778
+ if not omit_sites and all_edge_mutations:
1779
+ tree_left = tree.interval.left
1780
+ tree_right = tree.interval.right
1781
+ edge_left = ts.tables.edges.left
1782
+ edge_right = ts.tables.edges.right
1783
+ node_edges = tree.edge_array
1784
+ # whittle mutations down so we only need look at those above the tree nodes
1785
+ mut_t = ts.tables.mutations
1786
+ focal_mutations = np.isin(mut_t.node, np.fromiter(nodes, mut_t.node.dtype))
1787
+ mutation_nodes = mut_t.node[focal_mutations]
1788
+ mutation_positions = ts.tables.sites.position[mut_t.site][focal_mutations]
1789
+ mutation_ids = np.arange(ts.num_mutations, dtype=int)[focal_mutations]
1790
+ for m_id, node, pos in zip(
1791
+ mutation_ids, mutation_nodes, mutation_positions
1792
+ ):
1793
+ curr_edge = node_edges[node]
1794
+ if curr_edge >= 0:
1795
+ if (
1796
+ edge_left[curr_edge] <= pos < tree_left
1797
+ ): # Mutation on this edge but to left of plotted tree
1798
+ self.node_mutations[node].append(ts.mutation(m_id))
1799
+ self.mutations_outside_tree.add(m_id)
1800
+ self.left_extent = min(self.left_extent, pos)
1801
+ elif (
1802
+ tree_right <= pos < edge_right[curr_edge]
1803
+ ): # Mutation on this edge but to right of plotted tree
1804
+ self.node_mutations[node].append(ts.mutation(m_id))
1805
+ self.mutations_outside_tree.add(m_id)
1806
+ self.right_extent = max(self.right_extent, pos)
1807
+ if self.right_extent != tree.interval.right:
1808
+ # Use nextafter so extent of plotting incorporates the mutation
1809
+ self.right_extent = np.nextafter(
1810
+ self.right_extent, self.right_extent + 1
1811
+ )
1812
+ # attributes for symbols
1813
+ half_symbol_size = f"{rnd(symbol_size / 2):g}"
1814
+ symbol_size = f"{rnd(symbol_size):g}"
1815
+ for u in tree.nodes():
1816
+ self.edge_attrs[u] = {}
1817
+ if edge_attrs is not None and u in edge_attrs:
1818
+ self.edge_attrs[u].update(edge_attrs[u])
1819
+ if tree.is_sample(u):
1820
+ # a square: set bespoke svgwrite params
1821
+ self.node_attrs[u] = {
1822
+ "size": (symbol_size,) * 2,
1823
+ "insert": ("-" + half_symbol_size,) * 2,
1824
+ }
1825
+ else:
1826
+ # a circle: set bespoke svgwrite param `centre` and default radius
1827
+ self.node_attrs[u] = {"center": (0, 0), "r": half_symbol_size}
1828
+ if node_attrs is not None and u in node_attrs:
1829
+ self.node_attrs[u].update(node_attrs[u])
1830
+ add_class(self.node_attrs[u], "sym") # class 'sym' for symbol
1831
+ label = ""
1832
+ if node_labels is None:
1833
+ label = str(u)
1834
+ elif u in node_labels:
1835
+ label = str(node_labels[u])
1836
+ self.node_label_attrs[u] = {"text": label}
1837
+ add_class(self.node_label_attrs[u], "lab") # class 'lab' for label
1838
+ if node_label_attrs is not None and u in node_label_attrs:
1839
+ self.node_label_attrs[u].update(node_label_attrs[u])
1840
+ for _, mutations in self.node_mutations.items():
1841
+ for mutation in mutations:
1842
+ m = mutation.id + self.offsets.mutation
1843
+ # We need to offset the mutation symbol so that it's centred
1844
+ self.mutation_attrs[m] = {
1845
+ "d": "M -{0},-{0} l {1},{1} M -{0},{0} l {1},-{1}".format(
1846
+ half_symbol_size, symbol_size
1847
+ )
1848
+ }
1849
+ if mutation_attrs is not None and m in mutation_attrs:
1850
+ self.mutation_attrs[m].update(mutation_attrs[m])
1851
+ add_class(self.mutation_attrs[m], "sym") # class 'sym' for symbol
1852
+ label = ""
1853
+ if mutation_labels is None:
1854
+ label = str(m)
1855
+ elif m in mutation_labels:
1856
+ label = str(mutation_labels[m])
1857
+ self.mutation_label_attrs[m] = {"text": label}
1858
+ if mutation_label_attrs is not None and m in mutation_label_attrs:
1859
+ self.mutation_label_attrs[m].update(mutation_label_attrs[m])
1860
+ add_class(self.mutation_label_attrs[m], "lab")
1861
+
1862
+ self.set_spacing(
1863
+ top=self.margin_top + (0 if title is None else self.line_height),
1864
+ left=self.margin_left,
1865
+ bottom=self.margin_bottom,
1866
+ right=self.margin_right,
1867
+ )
1868
+ if title is not None:
1869
+ self.add_text_in_group(
1870
+ title,
1871
+ self.drawing,
1872
+ pos=(self.plotbox.max_x / 2, 0),
1873
+ dominant_baseline="hanging",
1874
+ group_class="title",
1875
+ text_anchor="middle",
1876
+ )
1877
+
1878
+ self.assign_x_coordinates()
1879
+ self.assign_y_coordinates(max_time, min_time, force_root_branch)
1880
+ tick_length_lower = self.default_tick_length # TODO - parameterize
1881
+ tick_length_upper = self.default_tick_length_site # TODO - parameterize
1882
+ if all_edge_mutations:
1883
+ self.shade_background(tree.interval, tick_length_lower)
1884
+
1885
+ first_site, last_site = np.searchsorted(
1886
+ self.ts.tables.sites.position, [self.left_extent, self.right_extent]
1887
+ )
1888
+ site_muts = {site_id: [] for site_id in range(first_site, last_site)}
1889
+ # Only use mutations plotted on the tree (not necessarily all at the site)
1890
+ for muts in self.node_mutations.values():
1891
+ for mut in muts:
1892
+ site_muts[mut.site].append(mut)
1893
+
1894
+ self.draw_x_axis(
1895
+ tick_positions=np.array(tree.interval),
1896
+ tick_length_lower=tick_length_lower,
1897
+ tick_length_upper=tick_length_upper,
1898
+ site_muts=site_muts,
1899
+ x_regions=x_regions,
1900
+ )
1901
+ if y_ticks is None:
1902
+ y_ticks = {h: ts.node(u).time for u, h in sorted(self.node_height.items())}
1903
+
1904
+ self.draw_y_axis(
1905
+ ticks=check_y_ticks(y_ticks),
1906
+ lower=self.timescaling.transform(self.timescaling.min_time),
1907
+ tick_length_outer=self.default_tick_length,
1908
+ gridlines=y_gridlines,
1909
+ side="right" if y_axis == "right" else "left",
1910
+ )
1911
+ self.draw_tree()
1912
+
1913
def process_mutations_over_node(self, u, low_bound, high_bound, ignore_times=False):
    """
    Order the mutations stored in ``self.node_mutations[u]`` from oldest to
    youngest, sorting the list in place.

    If every mutation has an unknown time, or ``ignore_times`` is ``True``, the
    mutations are ordered by site and then by parent, and each one is given a
    synthetic time, evenly spaced strictly between ``high_bound`` and
    ``low_bound``. Otherwise all times must be known (a mix of known and unknown
    times is not supported) and the mutations are sorted by decreasing time.
    """
    muts = self.node_mutations[u]
    unknown_flags = [util.is_unknown_time(m.time) for m in muts]
    if ignore_times is not True and not all(unknown_flags):
        # Real times are available: a partially-timed set is not expected here
        assert not any(unknown_flags)
        muts.sort(key=operator.attrgetter("time"), reverse=True)
        return
    # sort by site then within site by parent: will end up with oldest first
    muts.sort(key=operator.attrgetter("site", "parent"))
    span = high_bound - low_bound
    num_slots = len(muts) + 1
    for rank, mut in enumerate(muts, start=1):
        mut.time = high_bound - span * rank / num_slots
1933
+
1934
def assign_y_coordinates(
    self,
    max_time,
    min_time,
    force_root_branch,
    bottom_space=SvgAxisPlot.line_height,
    top_space=SvgAxisPlot.line_height,
):
    """
    Create a self.node_height dict, a self.timescaling instance and
    self.min_root_branch_plot_length for use in plotting. Allow extra space within
    the plotbox, at the bottom for leaf labels, and (potentially, if no root
    branches are plotted) above the topmost root node for root labels.

    As a side effect, the mutations above each plotted node are ordered (and, if
    needed, given synthetic times) via ``process_mutations_over_node``.

    :param max_time: Validated by ``check_max_time``. The values ``"tree"`` and
        ``"ts"`` are handled specially below; any other value is used directly as
        the upper time limit (not in "rank" mode, where it is replaced by a rank).
    :param min_time: Validated by ``check_min_time``. ``"tree"`` and ``"ts"`` (and
        ``None`` in "rank" mode) are replaced by a computed lower limit.
    :param force_root_branch: If true, reserve space for a branch above the
        root(s) even when ``self.mutations_over_roots`` is false.
    :param bottom_space: Plot-coordinate space kept free below the tree.
    :param top_space: Plot-coordinate space kept free above the tree; ignored
        (treated as 0) when a root branch is drawn.
    """
    # The second argument is False only for the "rank" time scale
    # NOTE(review): presumably this is an ``allow_numeric`` flag - confirm
    max_time = check_max_time(max_time, self.time_scale != "rank")
    min_time = check_min_time(min_time, self.time_scale != "rank")
    node_time = self.ts.nodes_time
    mut_time = self.ts.mutations_time
    root_branch_len = 0
    if self.time_scale == "rank":
        # Build a copy of the node times in which unused nodes are left at zero,
        # so that only the relevant nodes contribute distinct ranks
        t = np.zeros_like(node_time)
        if max_time == "tree":
            # We only rank the times within the tree in this case.
            for u in self.node_x_coord.keys():
                t[u] = node_time[u]
        else:
            # only rank the nodes that are actually referenced in the edge table
            # (non-referenced nodes could occur if the user specifies x_lim values)
            # However, we do include nodes in trees that have been skipped
            use_time = edge_and_sample_nodes(self.ts)
            t[use_time] = node_time[use_time]
        node_time = t
        times = np.unique(node_time[node_time <= self.ts.max_root_time])
        max_node_height = len(times)
        # Map each distinct time to its rank (0 == youngest)
        depth = {t: j for j, t in enumerate(times)}
        if self.mutations_over_roots or force_root_branch:
            root_branch_len = 1  # Will get scaled later
        max_time = max(depth.values()) + root_branch_len
        if min_time in (None, "tree", "ts"):
            assert min(depth.values()) == 0
            min_time = 0
        # In pathological cases, all the nodes are at the same time
        if max_time == min_time:
            max_time = min_time + 1
        self.node_height = {
            u: depth[node_time[u]] for u in self.node_x_coord.keys()
        }
        for u in self.node_mutations.keys():
            if u in self.node_height:
                parent = self.tree.parent(u)
                if parent == NULL:
                    top = self.node_height[u] + root_branch_len
                else:
                    top = depth[node_time[parent]]
                # In rank mode real mutation times are meaningless, so mutations
                # are always spaced evenly along the branch
                self.process_mutations_over_node(
                    u, self.node_height[u], top, ignore_times=True
                )
    else:
        assert self.time_scale in ["time", "log_time"]
        self.node_height = {u: node_time[u] for u in self.node_x_coord.keys()}
        if max_time == "tree":
            max_node_height = max(self.node_height.values())
            # nanmax skips unknown (NaN) mutation times; the leading 0 guards
            # against an empty list
            max_mut_height = np.nanmax(
                [0] + [mut.time for m in self.node_mutations.values() for mut in m]
            )
            max_time = max(max_node_height, max_mut_height)  # Reuse variable
        elif max_time == "ts":
            max_node_height = self.ts.max_root_time
            max_mut_height = np.nanmax(np.append(mut_time, 0))
            max_time = max(max_node_height, max_mut_height)  # Reuse variable
        else:
            max_node_height = max_time
        if min_time == "tree":
            min_time = min(self.node_height.values())
            # don't need to check mutation times, as they must be above a node
        elif min_time == "ts":
            min_time = np.min(self.ts.nodes_time[edge_and_sample_nodes(self.ts)])
        # In pathological cases, all the nodes are at the same time
        if min_time == max_time:
            max_time = min_time + 1
        if self.mutations_over_roots or force_root_branch:
            # Define a minimum root branch length, after transformation if necessary
            if self.time_scale != "log_time":
                root_branch_len = (max_time - min_time) * self.root_branch_fraction
            else:
                # Work in log(time + 1) space, then convert the extended top
                # position back into a length in untransformed time units
                max_plot_y = np.log(max_time + 1)
                diff_plot_y = max_plot_y - np.log(min_time + 1)
                root_plot_y = max_plot_y + diff_plot_y * self.root_branch_fraction
                root_branch_len = np.exp(root_plot_y) - 1 - max_time
            # If necessary, allow for this extra branch in max_time
            if max_node_height + root_branch_len > max_time:
                max_time = max_node_height + root_branch_len
        for u in self.node_mutations.keys():
            if u in self.node_height:
                parent = self.tree.parent(u)
                if parent == NULL:
                    # This is a root: if muts have no times we specify an upper time
                    top = self.node_height[u] + root_branch_len
                else:
                    top = node_time[parent]
                self.process_mutations_over_node(u, self.node_height[u], top)

    # By this point both limits must have been resolved to numbers
    assert float(max_time) == max_time
    assert float(min_time) == min_time
    # Add extra space above the top and below the bottom of the tree to keep the
    # node labels within the plotbox (but top label space not needed if the
    # existence of a root branch pushes the whole tree + labels downwards anyway)
    top_space = 0 if root_branch_len > 0 else top_space
    self.timescaling = Timescaling(
        max_time=max_time,
        min_time=min_time,
        plot_min=self.plotbox.height + self.plotbox.top - bottom_space,
        plot_range=self.plotbox.height - top_space - bottom_space,
        use_log_transform=(self.time_scale == "log_time"),
    )

    # Calculate default root branch length to use (in plot coords). This is a
    # minimum, as branches with deep root mutations could be longer
    self.min_root_branch_plot_length = self.timescaling.transform(
        self.timescaling.max_time
    ) - self.timescaling.transform(self.timescaling.max_time + root_branch_len)
2055
+
2056
def assign_x_coordinates(self):
    """
    Set ``self.x_transform`` (genome position -> plot x coordinate),
    ``self.node_x_coord`` (node id -> plot x coordinate, keyed in the order the
    nodes were placed) and ``self.extra_line`` (node id -> polytomy summary line).

    Nodes are visited in the order given by ``self.postorder_nodes``. A node is
    treated as a tip of the displayed tree if the previously visited node is not
    one of its children. If ``self.pack_untracked_polytomies`` is set, nodes with
    no tracked samples are not given their own position: a single such child is
    drawn as a condensed tip, while several are summarised by a horizontal line
    stored in ``self.extra_line``.

    :raises ValueError: If a node is visited directly after its own parent, i.e.
        the nodes are not in postorder.
    """
    # Set up transformation for genome positions. The extents are looked up at
    # call time, so later changes to them are reflected in the transform.
    self.x_transform = lambda x: (
        (x - self.left_extent)
        / (self.right_extent - self.left_extent)
        * self.plotbox.width
        + self.plotbox.left
    )
    # Set up x positions for nodes
    node_xpos = {}
    self.extra_line = {}  # To store a dotted line to represent polytomies
    leaf_x = 0  # First leaf starts at x=1, to give some space between Y axis & leaf
    tree = self.tree
    prev = tree.virtual_root
    for u in self.postorder_nodes:
        parent = tree.parent(u)
        omit = self.pack_untracked_polytomies and tree.num_tracked_samples(u) == 0
        if parent == prev:
            raise ValueError("Nodes must be passed in postorder to Tree.draw_svg()")
        is_tip = tree.parent(prev) != u
        if is_tip:
            if not omit:
                leaf_x += 1
                node_xpos[u] = leaf_x
        elif not omit:
            # Untracked children are available for packing into a polytomy summary
            untracked_children = []
            if self.pack_untracked_polytomies:
                untracked_children += [
                    c for c in tree.children(u) if tree.num_tracked_samples(c) == 0
                ]
            child_x = [node_xpos[c] for c in tree.children(u) if c in node_xpos]
            if len(untracked_children) > 0:
                if len(untracked_children) <= 1:
                    # If only a single non-focal lineage, treat it as a condensed tip
                    for child in untracked_children:
                        leaf_x += 1
                        node_xpos[child] = leaf_x
                        child_x.append(leaf_x)
                else:
                    # Otherwise show a horizontal line with the number of lineages
                    # Extra length of line is equal to log of the polytomy size
                    self.extra_line[u] = self.PolytomyLine(
                        len(untracked_children),
                        sum(tree.num_samples(v) for v in untracked_children),
                        [leaf_x, leaf_x + 1 + np.log(len(untracked_children))],
                    )
                    child_x.append(leaf_x + 1)
                    leaf_x = self.extra_line[u].line_pos[1]
            assert len(child_x) != 0  # Must have prev hit something defined as a tip
            if len(child_x) == 1:
                node_xpos[u] = child_x[0]
            else:
                # Centre the parent between its outermost children
                a = min(child_x)
                b = max(child_x)
                node_xpos[u] = a + (b - a) / 2
        prev = u
    # Now rescale to the plot width: leaf_x is the maximum value of the last leaf
    if len(node_xpos) > 0:
        scale = self.plotbox.width / leaf_x
        lft = self.plotbox.left - scale / 2
        self.node_x_coord = {k: lft + v * scale for k, v in node_xpos.items()}
        for v in self.extra_line.values():
            # Rescale in place so the existing line_pos container is kept
            v.line_pos[:] = [lft + pos * scale for pos in v.line_pos]
2122
+
2123
def info_classes(self, focal_node_id):
    """
    For a focal node id, return a sorted list of class strings used to target the
    node's SVG group from CSS. The strings encode:
        "node n<Y>": where <Y> == focal node id (a single space-separated entry)
        "a<X>" or "root": where <X> == id of immediate ancestor (parent) node
        "i<I>": where <I> == individual id (only if the node has an individual)
        "p<P>": where <P> == population id (only if the node has a population)
        "sample": if the node is a sample in this tree
        "c<N>" or "leaf": where <N> == number of direct children of this node
        "m<A>": where <A> == (offset) id of each mutation above this node
        "s<B>": where <B> == (offset) site id of each mutation above this node
    """
    node = self.ts.node(focal_node_id)
    names = [f"node n{focal_node_id}"]
    if node.individual != NULL:
        names.append(f"i{node.individual}")
    if node.population != NULL:
        names.append(f"p{node.population}")
    ancestor = self.tree.parent(focal_node_id)
    names.append("root" if ancestor == NULL else f"a{ancestor}")
    if self.tree.is_sample(focal_node_id):
        names.append("sample")
    if self.tree.is_leaf(focal_node_id):
        names.append("leaf")
    else:
        names.append(f"c{self.tree.num_children(focal_node_id)}")
    # Adding mutations and sites above this node allows identification
    # of the tree under any specific mutation
    for mut in self.node_mutations[focal_node_id]:
        names.append(f"m{mut.id + self.offsets.mutation}")
        names.append(f"s{mut.site + self.offsets.site}")
    # Several mutations may share a site: de-duplicate before sorting
    return sorted(set(names))
2159
+
2160
def text_transform(self, position, dy=0):
    """
    Return the SVG ``transform`` string used to place a text label relative to
    the symbol it annotates.

    :param str position: One of "below", "above", "above_left", "above_right",
        "left" or "right".
    :param dy: Extra vertical offset added to the y translation (positive moves
        the label down the page, as in SVG coordinates).
    :return: A string of the form ``"translate(<x> <y>)"``.
    :raises KeyError: If ``position`` is not one of the names listed above.
    """
    line_h = self.text_height
    sym_sz = self.symbol_size
    transforms = {
        "below": f"translate(0 {rnd(line_h - sym_sz / 2 + dy)})",
        "above": f"translate(0 {rnd(-(line_h - sym_sz / 2) + dy)})",
        "above_left": f"translate({rnd(-sym_sz / 2)} {rnd(-line_h / 2 + dy)})",
        # dy is added (not negated) so that "above_right" shifts in the same
        # direction as every other position for a given dy
        "above_right": f"translate({rnd(sym_sz / 2)} {rnd(-line_h / 2 + dy)})",
        "left": f"translate({-rnd(2 + sym_sz / 2)} {rnd(dy)})",
        "right": f"translate({rnd(2 + sym_sz / 2)} {rnd(dy)})",
    }
    return transforms[position]
2172
+
2173
def draw_tree(self):
    """
    Add the tree itself to the drawing, as a set of nested SVG groups that mirror
    the hierarchy of the displayed tree. Each node's group is translated relative
    to its parent's group and contains, in this order: the groups of its
    children, an optional polytomy summary line, the edge above the node, the
    mutations above the node, the node symbol and finally the node label.

    Relies on ``self.node_x_coord``, ``self.node_height``, ``self.timescaling``
    and ``self.min_root_branch_plot_length`` having been set beforehand.
    """
    # Note: the displayed tree may not be the same as self.tree, e.g. if the nodes
    # have been collapsed, or a subtree is being displayed. The node_x_coord
    # dictionary keys gives the nodes of the displayed tree, in postorder.
    NodeDrawInfo = collections.namedtuple("NodeDrawInfo", ["pos", "is_tip"])
    dwg = self.drawing
    tree = self.tree
    left_child = get_left_child(tree, self.postorder_nodes)
    parent_array = tree.parent_array
    edge_array = tree.edge_array

    node_info = {}
    roots = []  # Roots of the displayed tree
    prev = tree.virtual_root
    for u, x in self.node_x_coord.items():  # Node ids `u` returned in postorder
        node_info[u] = NodeDrawInfo(
            pos=np.array([x, self.timescaling.transform(self.node_height[u])]),
            # Detect if this is a "tip" in the displayed tree, even if
            # it is not a leaf in the original tree, by looking at the prev parent
            is_tip=(parent_array[prev] != u),
        )
        prev = u
        if parent_array[u] not in self.node_x_coord:
            roots.append(u)
    # Iterate over displayed nodes, adding groups to reflect the tree hierarchy
    stack = []
    for u in roots:
        x, y = node_info[u].pos
        grp = dwg.g(
            class_=" ".join(self.info_classes(u)),
            transform=f"translate({rnd(x)} {rnd(y)})",
        )
        stack.append((u, self.get_plotbox().add(grp)))

    # Preorder traversal, so we can create nested groups
    while len(stack) > 0:
        u, curr_svg_group = stack.pop()
        pu, is_tip = node_info[u]
        for focal in tree.children(u):
            if focal not in node_info:
                continue
            # Child groups are positioned relative to the parent's group
            fx, fy = node_info[focal].pos - pu
            new_svg_group = curr_svg_group.add(
                dwg.g(
                    class_=" ".join(self.info_classes(focal)),
                    transform=f"translate({rnd(fx)} {rnd(fy)})",
                )
            )
            stack.append((focal, new_svg_group))

        o = (0, 0)
        v = parent_array[u]

        # Add polytomy line if necessary
        if u in self.extra_line:
            info = self.extra_line[u]
            x2 = info.line_pos[1] - pu[0]
            poly = dwg.g(class_="polytomy")
            poly.add(
                dwg.line(
                    start=(0, 0),
                    end=(x2, 0),
                )
            )
            label = dwg.text(
                f"+{info.num_samples}/{bold_integer(info.num_branches)}",
                font_style="italic",
                x=[rnd(x2)],
                dy=[rnd(-self.text_height / 10)],  # make the plus sign line up
                text_anchor="end",
            )
            label.set_desc(
                title=(
                    f"This polytomy has {info.num_branches} additional branches, "
                    f"leading to a total of {info.num_samples} descendant samples"
                )
            )
            poly.add(label)
            curr_svg_group.add(poly)

        # Add edge above node first => on layer underneath anything else
        draw_edge_above_node = False
        try:
            dx, dy = node_info[v].pos - pu
            draw_edge_above_node = True
        except KeyError:
            # Must be a root
            root_branch_l = self.min_root_branch_plot_length
            if root_branch_l > 0:
                if len(self.node_mutations[u]) > 0:
                    # Lengthen the root branch if needed to reach the first
                    # mutation in the (already sorted) list above this root
                    mtop = self.timescaling.transform(
                        self.node_mutations[u][0].time
                    )
                    root_branch_l = max(root_branch_l, pu[1] - mtop)
                dx, dy = 0, -root_branch_l
                draw_edge_above_node = True
        if draw_edge_above_node:
            edge_id_class = (
                "root" if edge_array[u] == tskit.NULL else f"e{edge_array[u]}"
            )
            add_class(self.edge_attrs[u], f"edge {edge_id_class}")
            path = dwg.path(
                [("M", o), ("V", rnd(dy)), ("H", rnd(dx))], **self.edge_attrs[u]
            )
            curr_svg_group.add(path)

        # Add mutation symbols + labels
        for mutation in self.node_mutations[u]:
            # TODO get rid of these manual positioning tweaks and add them
            # as offsets the user can access via a transform or something.
            dy = self.timescaling.transform(mutation.time) - pu[1]
            mutation_id = mutation.id + self.offsets.mutation
            mutation_class = (
                f"mut m{mutation_id} " f"s{mutation.site + self.offsets.site}"
            )
            # Use the real mutation ID here, since we are referencing into the ts
            if util.is_unknown_time(self.ts.mutation(mutation.id).time):
                mutation_class += " unknown_time"
            # mutations_outside_tree is populated with real (non-offset) mutation
            # IDs, so it must also be queried with the real ID
            if mutation.id in self.mutations_outside_tree:
                mutation_class += " extra"
            mut_group = curr_svg_group.add(
                dwg.g(class_=mutation_class, transform=f"translate(0 {rnd(dy)})")
            )
            # A line from the mutation to the node below, normally hidden, but
            # revealable if we want to flag the path below a mutation
            mut_group.add(dwg.line(end=(0, -rnd(dy))))
            # Symbols
            symbol = mut_group.add(dwg.path(**self.mutation_attrs[mutation_id]))
            if mutation_id in self.mutation_titles:
                symbol.set_desc(title=self.mutation_titles[mutation_id])
            # Labels
            # For a root, parent_array[u] is -1, which indexes the extra final
            # slot of left_child (always NULL), so roots are labelled "rgt"
            if u == left_child[parent_array[u]]:
                mut_label_class = "lft"
                transform = self.text_transform("left")
            else:
                mut_label_class = "rgt"
                transform = self.text_transform("right")
            add_class(self.mutation_label_attrs[mutation_id], mut_label_class)
            self.mutation_label_attrs[mutation_id]["transform"] = transform
            mut_group.add(dwg.text(**self.mutation_label_attrs[mutation_id]))

        # Add node symbol + label (visually above the edge subtending this node)
        # -> symbols
        if tree.is_sample(u):
            symbol = curr_svg_group.add(dwg.rect(**self.node_attrs[u]))
        else:
            symbol = curr_svg_group.add(dwg.circle(**self.node_attrs[u]))
        multi_samples = None
        if (
            is_tip and tree.num_samples(u) > 1
        ):  # Multi-sample tip => trapezium shape
            multi_samples = tree.num_samples(u)
            trapezium_attrs = self.node_attrs[u].copy()
            # Remove the shape-styling attributes
            for unwanted_attr in ("size", "insert", "center", "r"):
                trapezium_attrs.pop(unwanted_attr, None)
            trapezium_attrs["points"] = [  # add a trapezium shape below the symbol
                (self.symbol_size / 2, 0),
                (self.symbol_size, self.symbol_size),
                (-self.symbol_size, self.symbol_size),
                (-self.symbol_size / 2, 0),
            ]
            add_class(trapezium_attrs, "multi")
            curr_svg_group.add(dwg.polygon(**trapezium_attrs))
        if u in self.node_titles:
            symbol.set_desc(title=self.node_titles[u])
        # -> labels
        node_lab_attr = self.node_label_attrs[u]
        if is_tip and multi_samples is None:
            node_lab_attr["transform"] = self.text_transform("below")
        elif u in roots and self.min_root_branch_plot_length == 0:
            node_lab_attr["transform"] = self.text_transform("above")
        else:
            if multi_samples is not None:
                # The space below a collapsed tip is used for a sample count,
                # so the node's own label goes above, like an internal node
                label = dwg.text(
                    text=f"+{multi_samples}",
                    transform=self.text_transform("below", dy=1),
                    font_style="italic",
                    class_="lab summary",
                )
                title = (
                    f"A collapsed {'sample' if tree.is_sample(u) else 'non-sample'} "
                    f"node with {multi_samples} descendant samples in this tree"
                )
                label.set_desc(title=title)
                curr_svg_group.add(label)
            if u == left_child[tree.parent(u)]:
                add_class(node_lab_attr, "lft")
                node_lab_attr["transform"] = self.text_transform("above_left")
            else:
                add_class(node_lab_attr, "rgt")
                node_lab_attr["transform"] = self.text_transform("above_right")
        curr_svg_group.add(dwg.text(**node_lab_attr))
2366
+
2367
+
2368
class TextTreeSequence:
    """
    Draw a tree sequence as a horizontal line of text trees, with a column of
    node-time labels on the left and a row of breakpoint labels along the bottom.
    """

    def __init__(
        self,
        ts,
        node_labels=None,
        use_ascii=False,
        time_label_format=None,
        position_label_format=None,
        order=None,
    ):
        self.ts = ts

        if time_label_format is None:
            time_label_format = "{:.2f}"
        breakpoints = ts.breakpoints(as_array=True)
        if position_label_format is None:
            position_scale_labels = create_tick_labels(breakpoints)
        else:
            position_scale_labels = [
                position_label_format.format(pos) for pos in breakpoints
            ]

        node_times = ts.tables.nodes.time
        time_scale_labels = [
            time_label_format.format(node_times[u]) for u in range(ts.num_nodes)
        ]

        # Every tree is laid out on the shared "ts" time scale, so that a given
        # node appears on the same row in every tree
        trees = [
            VerticalTextTree(
                tree,
                max_time="ts",
                node_labels=node_labels,
                use_ascii=use_ascii,
                order=order,
            )
            for tree in self.ts.trees()
        ]

        time_column_width = max(len(text) for text in time_scale_labels)
        final_label_overhang = len(position_scale_labels[-1]) // 2
        self.height = 1 + max(tree.height for tree in trees)
        self.width = (
            sum(tree.width + 2 for tree in trees)
            - 1
            + 3
            + time_column_width
            + final_label_overhang
        )

        self.canvas = np.zeros((self.height, self.width), dtype=str)
        self.canvas[:] = " "

        vertical_sep = "|" if use_ascii else "┊"
        # Time labels go in the leftmost columns, on the row of their node
        time_position = trees[0].time_position
        for u, text in enumerate(time_scale_labels):
            chars = to_np_unicode(text)
            self.canvas[time_position[u], 0 : chars.shape[0]] = chars
        self.canvas[:, time_column_width] = vertical_sep
        x = 2 + time_column_width

        for j, tree in enumerate(trees):
            # Label for the breakpoint at the left edge of this tree
            self._write_position_label(position_scale_labels[j], x)
            tree_rows, tree_cols = tree.canvas.shape
            # The last canvas column of each tree holds newlines: leave it out
            self.canvas[-tree_rows - 1 : -1, x : x + tree_cols - 1] = tree.canvas[
                :, :-1
            ]
            x += tree_cols
            self.canvas[:, x] = vertical_sep
            x += 2

        # Label for the rightmost breakpoint
        self._write_position_label(position_scale_labels[-1], x)
        self.canvas[:, -1] = "\n"

    def _write_position_label(self, text, x):
        # Write a breakpoint label on the bottom row, roughly centred on the
        # separator that precedes column x, clamped to the canvas's left edge
        chars = to_np_unicode(text)
        num_chars = len(chars)
        start = max(x - num_chars // 2 - 2, 0)
        self.canvas[-1, start : start + num_chars] = chars

    def __str__(self):
        return "".join(self.canvas.reshape(self.height * self.width))
2445
+
2446
+
2447
def to_np_unicode(string):
    """
    Converts the specified string to a one-dimensional numpy unicode array
    holding one character per element.
    """
    # Splitting the string into a list of characters first stops numpy from
    # treating the whole string as a single zero-dimensional element. The
    # explicit "U1" dtype keeps the result well-typed for the empty string too.
    return np.array(list(string), dtype="U1")
2460
+
2461
+
2462
def get_left_neighbour(tree, traversal_order):
    """
    Returns the left-most neighbour of each node in the tree according to the
    specified traversal order. The left neighbour is the closest node in terms
    of path distance to the left of a given node.

    :param tree: The tree whose nodes are examined.
    :param traversal_order: Passed to ``tree.nodes(order=...)``; it defines the
        left-to-right order of siblings and of roots.
    :return: A numpy integer array indexed by node ID, holding the ID of each
        node's left neighbour, or NULL if there is none.
    """
    # The traversal order will define the order of children and roots.
    # Root order is defined by this traversal, and the roots are
    # the children of -1
    children = collections.defaultdict(list)
    for u in tree.nodes(order=traversal_order):
        children[tree.parent(u)].append(u)

    # One extra slot at the end so that NULL (-1) is a valid index to write to
    left_neighbour = np.full(tree.tree_sequence.num_nodes + 1, NULL, dtype=int)

    # The children of -1 are the roots and the neighbour of all left-most
    # nodes in the tree is also -1 (NULL). An explicit stack is used instead of
    # recursion so that very deep trees cannot hit Python's recursion limit.
    stack = [(NULL, NULL)]
    while stack:
        u, neighbour = stack.pop()
        left_neighbour[u] = neighbour
        for v in children.get(u, ()):
            # The first child inherits its parent's neighbour; each later
            # child's neighbour is the sibling immediately to its left
            stack.append((v, neighbour))
            neighbour = v

    return left_neighbour[:-1]
2488
+
2489
+
2490
def get_left_child(tree, postorder_nodes):
    """
    Returns an array giving, for each node ID, the first of its children to
    appear in ``postorder_nodes``, or NULL if none of its children appear.
    The array has one extra trailing element which is always NULL, so that
    indexing it with NULL (-1) also gives NULL.
    """
    num_slots = tree.tree_sequence.num_nodes + 1
    first_child = np.full(num_slots, NULL, dtype=int)
    for node in postorder_nodes:
        above = tree.parent(node)
        # Skip roots, and parents whose left-most child is already recorded
        if above == NULL or first_child[above] != NULL:
            continue
        first_child[above] = node
    return first_child
2502
+
2503
+
2504
def _assign_depths_by_time(time_node_map, children_of, min_branch_length):
    """
    Assign an integer depth to every node in ``time_node_map`` (time -> list of
    node IDs), working from the youngest time to the oldest. All nodes sharing a
    time get the same depth, which is at least ``min_branch_length[v]`` above the
    depth of every child ``v`` returned by ``children_of(u)`` for any node ``u``
    at that time. Returns the depth dictionary and the next unused depth.
    """
    depth = {}
    current_depth = 0
    for t in sorted(time_node_map):
        nodes_at_t = time_node_map[t]
        for u in nodes_at_t:
            for v in children_of(u):
                current_depth = max(current_depth, depth[v] + min_branch_length[v])
        for u in nodes_at_t:
            depth[u] = current_depth
        current_depth += 2
    return depth, current_depth


def node_time_depth(tree, min_branch_length=None, max_time="tree"):
    """
    Returns a dictionary mapping nodes to their integer depth (increasing from
    the youngest nodes towards the roots), together with the total depth needed
    to draw them. Nodes with equal times are given equal depths.

    :param tree: The tree to process.
    :param dict min_branch_length: Maps each node ID to the minimum length of the
        branch above it. If not specified, defaults to 1 for every node.
    :param str max_time: If "tree", only the nodes and parent-child links in
        ``tree`` are considered. If "ts", all nodes and edges in the tree's tree
        sequence are considered.
    :return: A tuple ``(depth, total_depth)``.
    """
    if min_branch_length is None:
        min_branch_length = {u: 1 for u in range(tree.tree_sequence.num_nodes)}
    time_node_map = collections.defaultdict(list)
    if max_time == "tree":
        for u in tree.nodes():
            time_node_map[tree.time(u)].append(u)
        depth, current_depth = _assign_depths_by_time(
            time_node_map, tree.children, min_branch_length
        )
        # Leave room for the branch above each root of this tree
        for root in tree.roots:
            current_depth = max(current_depth, depth[root] + min_branch_length[root])
    else:
        assert max_time == "ts"
        ts = tree.tree_sequence
        for node in ts.nodes():
            time_node_map[node.time].append(node.id)
        # Children of each node over the whole tree sequence, one per edge
        edge_children = collections.defaultdict(list)
        for edge in ts.edges():
            edge_children[edge.parent].append(edge.child)
        depth, current_depth = _assign_depths_by_time(
            time_node_map, lambda u: edge_children.get(u, ()), min_branch_length
        )

    return depth, current_depth
2549
+
2550
+
2551
class TextTree:
    """
    Draws a representation of a tree using unicode (or ASCII) drawing characters
    written to a 2D array. This class only holds the shared setup: the
    ``default_node_label`` attribute and the ``_assign_time_positions``,
    ``_assign_traversal_positions`` and ``_draw`` methods used during
    construction are not defined here and must be supplied by subclasses.
    """

    def __init__(
        self,
        tree,
        node_labels=None,
        max_time=None,
        min_time=None,
        use_ascii=False,
        orientation=None,
        order=None,
    ):
        self.tree = tree
        self.traversal_order = check_order(order)
        self.max_time = check_max_time(max_time, allow_numeric=False)
        self.min_time = check_min_time(min_time, allow_numeric=False)
        self.use_ascii = use_ascii
        self.orientation = check_orientation(orientation)
        if use_ascii:
            self.horizontal_line_char, self.vertical_line_char = "-", "|"
        else:
            self.horizontal_line_char, self.vertical_line_char = "━", "┃"
        # These are set below by the placement algorithms.
        self.width = None
        self.height = None
        self.canvas = None
        # Placement of nodes in the 2D space. Nodes are positioned in one
        # dimension based on traversal ordering and by their time in the
        # other dimension. These are mapped to x and y coordinates according
        # to the orientation.
        self.traversal_position = {}  # Position of nodes in traversal space
        self.time_position = {}

        # Labels for nodes
        if node_labels is None:
            # If we don't specify node_labels, default to node ID
            self.node_labels = {u: str(u) for u in tree.nodes()}
        else:
            # If we do specify node_labels, unlisted nodes get the default
            # label, after which the user's labels are applied on top
            self.node_labels = {u: self.default_node_label for u in tree.nodes()}
            self.node_labels.update(node_labels.items())

        self._assign_time_positions()
        self._assign_traversal_positions()
        blank = np.zeros((self.height, self.width), dtype=str)
        blank[:] = " "
        self.canvas = blank
        self._draw()
        # The final column is reserved for the line terminators
        self.canvas[:, -1] = "\n"

    def __str__(self):
        return "".join(self.canvas.reshape(self.height * self.width))
2612
+
2613
+
2614
+ class VerticalTextTree(TextTree):
2615
+ """
2616
+ Text tree rendering where root nodes are at the top and time goes downwards
2617
+ into the present.
2618
+ """
2619
+
2620
@property
def default_node_label(self):
    """
    The label given to nodes that are missing from a user-supplied label
    mapping: the vertical line character, so the branch appears unbroken.
    """
    filler = self.vertical_line_char
    return filler
2623
+
2624
def _assign_time_positions(self):
    """
    Set ``self.time_position`` (node ID -> position on the time axis) and
    ``self.height`` from the depths computed by ``node_time_depth``.
    """
    # TODO when we add mutations to the text tree we'll need to take it into
    # account here. Presumably we need to get the maximum number of mutations
    # per branch.
    positions, total_depth = node_time_depth(self.tree, max_time=self.max_time)
    self.time_position = positions
    self.height = total_depth - 1
2631
+
2632
def _assign_traversal_positions(self):
    """
    Set ``self.traversal_position`` (the column of each node's branch),
    ``self.label_x`` (the column where each node's label starts) and
    ``self.width``. Leaves are laid out left to right in traversal order, one
    blank column apart; internal nodes sit midway between their outermost
    children, with labels pushed right if they would overlap the left neighbour.
    """
    self.label_x = {}
    tree = self.tree
    left_neighbour = get_left_neighbour(tree, self.traversal_order)
    next_free_x = 0
    for u in tree.nodes(order=self.traversal_order):
        label_size = len(self.node_labels[u])
        half_label = label_size // 2
        if tree.is_leaf(u):
            self.traversal_position[u] = next_free_x + half_label
            self.label_x[u] = next_free_x
            next_free_x += label_size + 1
        else:
            child_cols = [self.traversal_position[c] for c in tree.children(u)]
            if len(child_cols) == 1:
                centre = child_cols[0]
            else:
                lo = min(child_cols)
                hi = max(child_cols)
                centre = int(round(lo + (hi - lo) / 2))
            self.traversal_position[u] = centre
            # Keep the label to the right of the left neighbour's branch
            neighbour = left_neighbour[u]
            if neighbour != NULL:
                min_label_x = self.traversal_position[neighbour] + 1
            else:
                min_label_x = 0
            self.label_x[u] = max(min_label_x, centre - half_label)
            next_free_x = max(next_free_x, self.label_x[u] + label_size + 1)
        assert self.label_x[u] >= 0
    self.width = next_free_x
2660
+
2661
+ def _draw(self):
2662
+ if self.use_ascii:
2663
+ left_child = "+"
2664
+ right_child = "+"
2665
+ mid_parent = "+"
2666
+ mid_parent_child = "+"
2667
+ mid_child = "+"
2668
+ elif self.orientation == TOP:
2669
+ left_child = "┏"
2670
+ right_child = "┓"
2671
+ mid_parent = "┻"
2672
+ mid_parent_child = "╋"
2673
+ mid_child = "┳"
2674
+ else:
2675
+ left_child = "┗"
2676
+ right_child = "┛"
2677
+ mid_parent = "┳"
2678
+ mid_parent_child = "╋"
2679
+ mid_child = "┻"
2680
+
2681
+ for u in self.tree.nodes():
2682
+ xu = self.traversal_position[u]
2683
+ yu = self.time_position[u]
2684
+ label = to_np_unicode(self.node_labels[u])
2685
+ label_len = label.shape[0]
2686
+ label_x = self.label_x[u]
2687
+ assert label_x >= 0
2688
+ self.canvas[yu, label_x : label_x + label_len] = label
2689
+ children = self.tree.children(u)
2690
+ if len(children) > 0:
2691
+ if len(children) == 1:
2692
+ yv = self.time_position[children[0]]
2693
+ self.canvas[yv:yu, xu] = self.vertical_line_char
2694
+ else:
2695
+ left = min(self.traversal_position[v] for v in children)
2696
+ right = max(self.traversal_position[v] for v in children)
2697
+ y = yu - 1
2698
+ self.canvas[y, left + 1 : right] = self.horizontal_line_char
2699
+ self.canvas[y, xu] = mid_parent
2700
+ for v in children:
2701
+ xv = self.traversal_position[v]
2702
+ yv = self.time_position[v]
2703
+ self.canvas[yv:yu, xv] = self.vertical_line_char
2704
+ mid_char = mid_parent_child if xv == xu else mid_child
2705
+ self.canvas[y, xv] = mid_char
2706
+ self.canvas[y, left] = left_child
2707
+ self.canvas[y, right] = right_child
2708
+ if self.orientation == TOP:
2709
+ self.canvas = np.flip(self.canvas, axis=0)
2710
+ # Reverse the time positions so that we can use them in the tree
2711
+ # sequence drawing as well.
2712
+ flipped_time_position = {
2713
+ u: self.height - y - 1 for u, y in self.time_position.items()
2714
+ }
2715
+ self.time_position = flipped_time_position
2716
+
2717
+
2718
class HorizontalTextTree(TextTree):
    """
    Text rendering of a tree in which time runs along the columns of the
    canvas and the traversal order of the nodes runs along the rows. In the
    ``LEFT`` orientation the root nodes are at the left and time goes
    rightwards into the present.
    """

    @property
    def default_node_label(self):
        """
        Label used for a node that has no entry in a user-supplied set of
        node labels: the horizontal branch character.
        """
        return self.horizontal_line_char

    def _assign_time_positions(self):
        """
        Set ``self.time_position`` (the canvas column of each node) and
        ``self.width`` from the result of ``node_time_depth``.
        """
        # TODO when we add mutations to the text tree we'll need to take it into
        # account here. Presumably we need to get the maximum number of mutations
        # per branch.
        # One more than the label length for each node, so that there is room
        # for the label plus one extra character along the time axis.
        label_extent = {u: 1 + len(self.node_labels[u]) for u in self.tree.nodes()}
        positions, total_depth = node_time_depth(self.tree, label_extent)
        self.time_position = positions
        self.width = total_depth

    def _assign_traversal_positions(self):
        """
        Set ``self.traversal_position`` (the canvas row of each node) and
        ``self.height``.
        """
        tree = self.tree
        next_row = 0
        for root in tree.roots:
            for u in tree.nodes(root, order=self.traversal_order):
                if tree.is_leaf(u):
                    # Leaves take every second row.
                    self.traversal_position[u] = next_row
                    next_row += 2
                    continue
                child_rows = [self.traversal_position[c] for c in tree.children(u)]
                if len(child_rows) == 1:
                    # A unary node sits directly in line with its only child.
                    self.traversal_position[u] = child_rows[0]
                else:
                    # Otherwise centre the node between its outermost children.
                    lo = min(child_rows)
                    hi = max(child_rows)
                    self.traversal_position[u] = int(round(lo + (hi - lo) / 2))
            # One extra row after each root's subtree.
            next_row += 1
        self.height = next_row - 2

    def _draw(self):
        """
        Write the node labels and the branches joining each node to its
        children onto ``self.canvas``, then mirror the canvas horizontally
        when the orientation is ``LEFT``.
        """
        if self.use_ascii:
            top_across = bot_across = "+"
            mid_parent = mid_parent_child = mid_child = "+"
        elif self.orientation == LEFT:
            top_across, bot_across = "┏", "┗"
            mid_parent, mid_parent_child, mid_child = "┫", "╋", "┣"
        else:
            top_across, bot_across = "┓", "┛"
            mid_parent, mid_parent_child, mid_child = "┣", "╋", "┫"

        mirrored = self.orientation == LEFT
        canvas = self.canvas
        # Draw in root-right mode as the coordinates go in the expected direction.
        for u in self.tree.nodes():
            yu = self.traversal_position[u]
            xu = self.time_position[u]
            label = to_np_unicode(self.node_labels[u])
            if mirrored:
                # We flip the array at the end so need to reverse the label.
                label = label[::-1]
            canvas[yu, xu : xu + label.shape[0]] = label

            children = self.tree.children(u)
            if len(children) == 0:
                continue
            if len(children) == 1:
                # Unary node: a straight horizontal branch to the single child.
                x_child = self.time_position[children[0]]
                canvas[yu, x_child:xu] = self.horizontal_line_char
                continue

            child_rows = [self.traversal_position[v] for v in children]
            bot = min(child_rows)
            top = max(child_rows)
            # The vertical bar joining the children is in the column next to
            # the start of the parent's label.
            bar_x = xu - 1
            canvas[bot + 1 : top, bar_x] = self.vertical_line_char
            canvas[yu, bar_x] = mid_parent
            for v, yv in zip(children, child_rows):
                canvas[yv, self.time_position[v] : bar_x] = self.horizontal_line_char
                # A child directly in line with the parent needs a four-way join.
                canvas[yv, bar_x] = mid_parent_child if yv == yu else mid_child
            # The corners go on last so that they replace the join characters
            # written for the outermost children.
            canvas[bot, bar_x] = top_across
            canvas[top, bar_x] = bot_across

        if mirrored:
            self.canvas = np.flip(canvas, axis=1)
            # Move the padding to the left.
            self.canvas[:, :-1] = self.canvas[:, 1:]
            self.canvas[:, -1] = " "