arg-dashboard 0.1.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
arg_dashboard/arg.py ADDED
@@ -0,0 +1,1019 @@
1
+
2
+ import sys
3
+ from math import exp, log
4
+ #import networkx.algorithms.non_randomness
5
+
6
+ import sys
7
+ import logging
8
+
9
# Module logger, fixed at INFO level. No handler is attached here, so where
# the output goes is up to the application's logging configuration.
# NOTE: this rebinds ``log``, shadowing ``math.log`` imported above.
log = logging.getLogger('arg_dashboard.arg')
log.setLevel(logging.INFO)
14
+ import json
15
+
16
+ import random
17
+ # print("Setting random seed")
18
+ # random.seed(3)
19
+ # numpy.random.seed(3)
20
+
21
+ from random import random, shuffle, choices
22
+ from functools import partial
23
+ from copy import deepcopy, copy
24
+ # import networkx as nx
25
+ from functools import reduce
26
+ from numpy.random import exponential, sample
27
+ import numpy as np
28
+ from time import sleep
29
+ # import matplotlib.pyplot as plt
30
+ from collections import defaultdict
31
+
32
+ # from .arg_layout_json import compute_arg_xpos
33
+ # from .arg_layout_force import compute_arg_xpos
34
+ from .arg_layout_mindist import compute_arg_xpos
35
+
36
# Names exported by ``from arg_dashboard.arg import *``.
__all__ = ['Lineage', 'Leaf', 'get_arg_nodes', 'arg2json', 'json2arg']
43
+
44
class Lineage(object):
    """
    An edge of the ARG, running from a lower node (``down``) to an upper node (``up``).

    ``intervals`` is the list of ``(start, end)`` sequence intervals carried by the
    lineage. ``up`` stays ``None`` for the dangling lineage above the root.
    Equality and hashing use ``lineageid`` only.
    """

    def __init__(self, lineageid=None, down=None, up=None, intervals=None):
        self.lineageid = lineageid
        self.down = down  # node at the bottom of the edge
        self.up = up  # node at the top of the edge
        self.intervals = intervals

    def __hash__(self):
        return self.lineageid

    def __eq__(self, other):
        return hasattr(other, 'lineageid') and self.lineageid == other.lineageid

    def __repr__(self):
        return f'{self.lineageid}:{self.intervals}'

    def get_dict(self):
        """
        Return a JSON-serialisable copy of the attributes in which the ``down`` and
        ``up`` node references are replaced by their ``nodeid`` (``None`` kept as is).
        """
        d = self.__dict__.copy()
        if d['up'] is not None:
            d['up'] = d['up'].nodeid
        if d['down'] is not None:
            d['down'] = d['down'].nodeid
        return d

    def toJSON(self):
        """Return the lineage as an indented JSON string with sorted keys."""
        # json.dumps calls ``default(obj)`` with the unserialisable object, which the
        # bound get_dict() cannot accept (it raised TypeError); dump the dict directly.
        return json.dumps(self.get_dict(), sort_keys=True, indent=4)

    def __deepcopy__(self, memo):
        # copy.deepcopy keys its memo by id(). The duplicate is registered before
        # copying the node references so that reference cycles resolve to it.
        if id(self) in memo:
            return memo[id(self)]
        dup = type(self)(
            lineageid=deepcopy(self.lineageid, memo),
            intervals=deepcopy(self.intervals, memo),
        )
        memo[id(self)] = dup
        dup.down = deepcopy(self.down, memo)
        dup.up = deepcopy(self.up, memo)
        return dup
86
+
87
class Node:
    """
    Shared base of the graph nodes (Leaf, Coalescent and Recombination).

    Subclasses set ``nodeid``; two nodes compare equal, hash and print
    according to that id alone.
    """

    def __hash__(self):
        return self.nodeid

    def __eq__(self, other):
        if not hasattr(other, 'nodeid'):
            return False
        return self.nodeid == other.nodeid

    def __repr__(self):
        return format(self.nodeid)
99
+
100
class Leaf(Node):
    """
    A sampled node at the bottom of the ARG.

    ``parent`` is the Lineage leading up from the leaf and ``intervals`` the
    sequence intervals it carries (by default the whole sequence ``[(0, 1)]``).
    """

    def __init__(self, nodeid=None, height=None, parent=None, intervals=None, xpos=None):
        self.nodeid = nodeid
        self.height = height
        # Fresh list per leaf: the former literal default list was shared by
        # every Leaf created without explicit intervals.
        self.intervals = [(0, 1)] if intervals is None else intervals
        self.parent = parent
        self.xpos = xpos

    def get_dict(self):
        """
        Return a JSON-serialisable copy of the attributes with the ``parent``
        lineage replaced by its ``lineageid`` (``None`` kept as is).
        """
        d = self.__dict__.copy()
        if d['parent'] is not None:
            d['parent'] = d['parent'].lineageid
        return d

    def toJSON(self):
        """Return the leaf as an indented JSON string with sorted keys."""
        # json.dumps calls ``default(obj)`` with the object, which the bound
        # get_dict() cannot accept (TypeError); dump the dict directly.
        return json.dumps(self.get_dict(), sort_keys=True, indent=4)

    def __deepcopy__(self, memo):
        # copy.deepcopy keys its memo by id(); register the duplicate before
        # copying the parent lineage so reference cycles resolve to it.
        if id(self) in memo:
            return memo[id(self)]
        dup = type(self)(
            nodeid=deepcopy(self.nodeid, memo),
            height=deepcopy(self.height, memo),
            intervals=deepcopy(self.intervals, memo),
            xpos=deepcopy(self.xpos, memo),
        )
        memo[id(self)] = dup
        dup.parent = deepcopy(self.parent, memo)
        return dup
130
+
131
class Coalescent(Node):
    """
    Node where the ``children`` lineages merge into a single ``parent`` lineage
    (two children when built by ``get_arg_nodes``).
    """

    def __init__(self, nodeid=None, height=None, children=None, parent=None, xpos=None):
        self.nodeid = nodeid
        self.height = height
        self.children = children
        self.parent = parent
        self.xpos = xpos

    def get_dict(self):
        """
        Return a JSON-serialisable copy of the attributes with the lineage
        references replaced by their ``lineageid`` (``None`` kept as is).
        """
        d = self.__dict__.copy()
        if d['children'] is not None:
            d['children'] = [c.lineageid for c in d['children']]
        if d['parent'] is not None:
            d['parent'] = d['parent'].lineageid
        return d

    def toJSON(self):
        """Return the node as an indented JSON string with sorted keys."""
        # json.dumps calls ``default(obj)`` with the object, which the bound
        # get_dict() cannot accept (TypeError); dump the dict directly.
        return json.dumps(self.get_dict(), sort_keys=True, indent=4)

    def __deepcopy__(self, memo):
        # copy.deepcopy keys its memo by id(); register the duplicate before
        # copying the lineages so reference cycles resolve to it.
        if id(self) in memo:
            return memo[id(self)]
        dup = type(self)(
            nodeid=deepcopy(self.nodeid, memo),
            height=deepcopy(self.height, memo),
            xpos=deepcopy(self.xpos, memo),
        )
        memo[id(self)] = dup
        if self.children is not None:
            dup.children = [deepcopy(c, memo) for c in self.children]
        dup.parent = deepcopy(self.parent, memo)
        return dup
163
+
164
class Recombination(Node):
    """
    Node where the ``child`` lineage splits at ``recomb_point`` into a
    ``left_parent`` and a ``right_parent`` lineage (in ``get_arg_nodes`` the
    left parent takes the material left of the point, the right parent the rest).
    """

    def __init__(self, nodeid=None, height=None, child=None, recomb_point=None,
                 left_parent=None, right_parent=None, xpos=None):
        self.nodeid = nodeid
        self.height = height
        self.child = child
        self.left_parent = left_parent
        self.right_parent = right_parent
        self.recomb_point = recomb_point
        self.xpos = xpos

    def get_dict(self):
        """
        Return a JSON-serialisable copy of the attributes with the lineage
        references replaced by their ``lineageid`` (``None`` kept as is).
        """
        d = self.__dict__.copy()
        for key in ('left_parent', 'right_parent', 'child'):
            if d[key] is not None:
                d[key] = d[key].lineageid
        return d

    def toJSON(self):
        """Return the node as an indented JSON string with sorted keys."""
        # json.dumps calls ``default(obj)`` with the object, which the bound
        # get_dict() cannot accept (TypeError); dump the dict directly.
        return json.dumps(self.get_dict(), sort_keys=True, indent=4)

    def __deepcopy__(self, memo):
        # copy.deepcopy keys its memo by id(); register the duplicate before
        # copying the lineages so reference cycles resolve to it.
        if id(self) in memo:
            return memo[id(self)]
        dup = type(self)(
            nodeid=deepcopy(self.nodeid, memo),
            height=deepcopy(self.height, memo),
            recomb_point=deepcopy(self.recomb_point, memo),
            xpos=deepcopy(self.xpos, memo),
        )
        memo[id(self)] = dup
        dup.child = deepcopy(self.child, memo)
        dup.left_parent = deepcopy(self.left_parent, memo)
        dup.right_parent = deepcopy(self.right_parent, memo)
        return dup
202
+
203
def flatten(list_of_tps):
    """
    Flatten ``[(s1, e1), (s2, e2), ...]`` into the endpoint list ``[s1, e1, s2, e2, ...]``.

    Built in a single pass; the former ``reduce(ls + list(ival))`` copied the
    accumulator on every step and was quadratic.
    """
    return [endpoint for ival in list_of_tps for endpoint in ival]
206
+
207
+
208
def unflatten(list_of_endpoints):
    """
    Pair a flat endpoint list ``[s1, e1, s2, e2, ...]`` back into
    ``[[s1, e1], [s2, e2], ...]``; a trailing unpaired endpoint is dropped.
    """
    starts = list_of_endpoints[0::2]
    ends = list_of_endpoints[1::2]
    return [[start, end] for start, end in zip(starts, ends)]
212
+
213
def merge(query, annot, op):
    """
    Combine two interval lists with a boolean set operation using a sweep line.

    ``query`` and ``annot`` are sorted, non-overlapping lists of ``(start, end)``
    pairs (either may be empty). ``op(in_a, in_b)`` returns whether a point that
    is / is not inside each input belongs to the result. Returns a list of
    ``[start, end]`` lists.
    """
    a_endpoints = flatten(query)
    b_endpoints = flatten(annot)

    assert a_endpoints == sorted(a_endpoints), "not sorted or non-overlaping"
    assert b_endpoints == sorted(b_endpoints), "not sorted or non-overlaping"

    if not a_endpoints and not b_endpoints:
        return []

    # A sentinel beyond every endpoint terminates the sweep. An empty input then
    # consists of the sentinel alone and is treated as "never inside".
    sentinel = max(a_endpoints[-1:] + b_endpoints[-1:]) + 1
    a_endpoints.append(sentinel)
    b_endpoints.append(sentinel)

    a_index = 0
    b_index = 0
    res = []

    scan = min(a_endpoints[0], b_endpoints[0])
    while scan < sentinel:
        # Even index: the next endpoint is a start, so we are inside only once
        # scan has reached it. Odd index: an end, inside until scan reaches it.
        in_a = not ((scan < a_endpoints[a_index]) ^ (a_index % 2))
        in_b = not ((scan < b_endpoints[b_index]) ^ (b_index % 2))
        in_res = op(in_a, in_b)

        # An odd len(res) means we are currently inside the result; record a
        # boundary whenever membership changes.
        if in_res ^ (len(res) % 2):
            res.append(scan)
        if scan == a_endpoints[a_index]:
            a_index += 1
        if scan == b_endpoints[b_index]:
            b_index += 1
        scan = min(a_endpoints[a_index], b_endpoints[b_index])

    return unflatten(res)
246
+
247
def interval_diff(a, b):
    """Return the parts of interval list ``a`` not covered by interval list ``b``."""
    if not a:
        # Nothing to subtract from; the old ``a and a or b`` returned ``b`` here.
        return []
    if not b:
        return a
    return merge(a, b, lambda in_a, in_b: in_a and not in_b)
251
+
252
def interval_union(a, b):
    """Return the union of two sorted, non-overlapping interval lists."""
    if not a or not b:
        # The union with an empty list is the other list (this used to return []).
        return list(a or b or [])
    return merge(a, b, lambda in_a, in_b: in_a or in_b)
256
+
257
def interval_intersect(a, b):
    """
    Return the intersection of two sorted, non-overlapping interval lists;
    empty if either side is empty.
    """
    if a and b:
        return merge(a, b, lambda in_a, in_b: in_a and in_b)
    return []
261
+
262
def interval_sum(intervals):
    """Total length covered by a list of ``(start, end)`` intervals."""
    total = 0
    for start, end in intervals:
        total += end - start
    return total
264
+
265
def interval_span(intervals):
    """
    Distance from the leftmost start to the rightmost end of ``intervals``
    (gaps included). An empty list has span 0 instead of raising ValueError.
    """
    if not intervals:
        return 0
    starts, ends = zip(*intervals)
    return max(ends) - min(starts)
268
+
269
def interval_split(intervals, pos):
    """
    Split ``intervals`` at ``pos`` into ``(left, right)`` lists of ``(start, end)``.

    Intervals ending at or before ``pos`` go left, those starting at or after it
    go right, and an interval with ``start < pos < end`` is cut in two. No
    zero-length interval is produced (``pos == start`` used to add ``(s, s)`` left).
    """
    left, right = [], []
    for s, e in intervals:
        if pos >= e:
            left.append((s, e))
        elif pos > s:
            left.append((s, pos))
            right.append((pos, e))
        else:
            right.append((s, e))
    return left, right
280
+
281
def interval_any_shared_borders(a, b):
    """
    Return True if interval lists ``a`` and ``b`` share an endpoint, e.g. the
    adjacent intervals ``(0, 0.5)`` and ``(0.5, 1)``.

    Only endpoints shared *between* ``a`` and ``b`` count. The previous
    duplicate-count test also fired on borders shared inside a single list.
    The unconditional INFO log was dropped because this is called on every
    iteration of the SMC' rejection loop in ``get_arg_nodes``.
    """
    a_borders = {x for ival in a for x in ival}
    return any(x in a_borders for ival in b for x in ival)
285
+
286
def _get_neighbors_below(node):
    """
    Return the nodes at the bottom of ``node``'s child lineage(s).

    Leaves and any other node type have none.
    """
    node_type = type(node)
    if node_type is Recombination:
        return [node.child.down]
    if node_type is Coalescent:
        return [lineage.down for lineage in node.children]
    return []
293
+
294
def _get_neighbors_above(node):
    """
    Return the nodes at the top of ``node``'s parent lineage(s).

    Parent lineages that are missing or have no upper node (such as the
    dangling lineage above the root) are skipped. Recombination nodes report
    the left parent's node before the right parent's.
    """
    node_type = type(node)
    if node_type is Recombination:
        candidates = [node.left_parent, node.right_parent]
    elif node_type is Coalescent or node_type is Leaf:
        candidates = [node.parent]
    else:
        candidates = []
    return [lineage.up for lineage in candidates if lineage and lineage.up]
309
+
310
def _dfs_leaf_order(node, visited=None):
    """
    Depth-first walk down from ``node`` returning the leaves in visit order.

    Coalescent children are visited in a random order (``shuffle``), so repeated
    calls give different orderings for the multi-start layout. Nodes already in
    ``visited`` are skipped, so a node reachable along several paths is listed once.
    """
    if visited is None:
        visited = set()
    if node in visited:
        return []
    visited.add(node)

    node_type = type(node)
    if node_type is Leaf:
        return [node]
    if node_type is Recombination:
        return _dfs_leaf_order(node.child.down, visited)

    ordered = []
    if node_type is Coalescent:
        child_lineages = list(node.children)
        shuffle(child_lineages)
        for lineage in child_lineages:
            ordered += _dfs_leaf_order(lineage.down, visited)
    return ordered
329
+
330
+ ####################################################
331
+
332
def segments_crossing(line1, line2):
    """
    Return True if two line segments cross at a point strictly inside both.

    Each segment is ``((x1, y1), (x2, y2))``. Segments that only touch at a
    shared endpoint (e.g. two lineages meeting at the same node) and parallel
    or collinear segments do not count.

    Uses the parametric form P = P1 + t (P2 - P1), Q = P3 + u (P4 - P3) and
    requires 0 < t < 1 and 0 < u < 1. The previous test required the
    intersection to lie strictly inside both segments' x *and* y ranges. That
    can never hold for a vertical or horizontal segment, so crossings involving
    vertical lineages were never detected. Such lineages occur, for example,
    when a recombination node is placed straight above its child.
    """
    (x1, y1), (x2, y2) = line1
    (x3, y3), (x4, y4) = line2

    # Two distinct segments sharing an endpoint can only meet at that endpoint;
    # checking exactly avoids float round-off producing t or u just below 1.
    if (x1, y1) in ((x3, y3), (x4, y4)) or (x2, y2) in ((x3, y3), (x4, y4)):
        return False

    denom = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
    if denom == 0:
        # parallel or collinear
        return False

    t = ((x1 - x3) * (y3 - y4) - (y1 - y3) * (x3 - x4)) / denom
    u = ((x1 - x3) * (y1 - y2) - (y1 - y3) * (x1 - x2)) / denom
    return 0 < t < 1 and 0 < u < 1
355
+
356
+
357
def crossing(a, b):
    """
    Return True if lineages ``a`` and ``b`` cross when each is drawn as a straight
    line from its ``down`` node to its ``up`` node at ``(xpos, height)``.
    """
    def endpoints(lineage):
        bottom, top = lineage.down, lineage.up
        return (bottom.xpos, bottom.height), (top.xpos, top.height)

    return segments_crossing(endpoints(a), endpoints(b))
362
+
363
def get_all_crossing_pairs(lineages):
    """
    Return every index pair ``[i, j]`` with ``i < j`` (in lexicographic order)
    whose lineages cross according to ``crossing``. Tests all pairs.
    """
    count = len(lineages)
    return [
        [i, j]
        for i in range(count)
        for j in range(i + 1, count)
        if crossing(lineages[i], lineages[j])
    ]
370
+
371
def _swap_xpos(node_a, node_b):
    """Exchange the ``xpos`` of two nodes in place."""
    node_a.xpos, node_b.xpos = node_b.xpos, node_a.xpos


def _keep_swap_if_better(lineages, node_a, node_b, crossing_pairs):
    """
    Swap the x positions of two nodes and keep the swap only if it strictly
    reduces the number of crossing lineage pairs, otherwise undo it.

    Returns ``(crossing_pairs, improved)``: the recomputed pairs when the swap
    was kept, the input pairs otherwise.
    """
    _swap_xpos(node_a, node_b)
    new_pairs = get_all_crossing_pairs(lineages)
    if len(new_pairs) < len(crossing_pairs):
        return new_pairs, True
    _swap_xpos(node_a, node_b)
    return crossing_pairs, False


def reduce_crossovers(nodes):
    """
    One greedy pass that tries to remove lineage crossings by swapping x positions.

    For every pair of crossing lineages found at the start of the pass it tries:
      1. swapping the positions of their lower nodes,
      2. swapping the positions of their upper nodes,
      3. if exactly one of the two lineages hangs from a leaf, moving that leaf
         into another leaf's slot (first improving slot wins).
    Each swap is kept only if it strictly lowers the crossing count. The
    ``xpos`` of ``nodes`` is modified in place.

    Fixes over the previous version:
      - The "move leaf j" branch repeated the "move leaf i" condition, so it
        moved an internal node into leaf slots.
      - Leaf indices were taken from ``nodes`` but used to index ``lineages``.
      - The crossing count was not refreshed after a successful leaf move.
    """
    lineages = [lin for lin in get_parent_lineages(nodes) if lin.up is not None]
    crossing_pairs = get_all_crossing_pairs(lineages)
    # Positions (in ``lineages``) of the lineages directly above a leaf.
    leaf_lineage_idx = [k for k, lin in enumerate(lineages) if type(lin.down) is Leaf]

    # Walk the pairs found up front; indices into ``lineages`` remain valid
    # because only xpos values change.
    initial_pairs = crossing_pairs
    for i, j in initial_pairs:
        lin_i, lin_j = lineages[i], lineages[j]

        crossing_pairs, _ = _keep_swap_if_better(lineages, lin_i.down, lin_j.down, crossing_pairs)
        crossing_pairs, _ = _keep_swap_if_better(lineages, lin_i.up, lin_j.up, crossing_pairs)

        i_is_leaf = type(lin_i.down) is Leaf
        j_is_leaf = type(lin_j.down) is Leaf
        if i_is_leaf == j_is_leaf:
            continue

        # Crossing between an external and an internal branch: try to move
        # the leaf to a position that removes crossings.
        leaf = lin_i.down if i_is_leaf else lin_j.down
        for k in leaf_lineage_idx:
            crossing_pairs, improved = _keep_swap_if_better(
                lineages, leaf, lineages[k].down, crossing_pairs)
            if improved:
                break
443
+
444
+
445
+
446
+
447
def branch_length(lineage):
    """
    Euclidean length of ``lineage`` drawn from ``(down.xpos, down.height)``
    to ``(up.xpos, up.height)``.
    """
    dx = lineage.down.xpos - lineage.up.xpos
    dy = lineage.down.height - lineage.up.height
    return (dx ** 2 + dy ** 2) ** 0.5
449
+
450
def reduce_total_branch_length(nodes, shift):
    """
    Greedy coordinate descent on the total drawn branch length.

    For each non-root lineage in turn, move its upper node right by ``shift``
    or, failing that, left by ``shift``. A move is kept only if it strictly
    lowers the summed ``branch_length`` of all those lineages. ``xpos`` is
    modified in place.
    """
    lineages = [lin for lin in get_parent_lineages(nodes) if lin.up is not None]

    def total_length():
        return sum(branch_length(lin) for lin in lineages)

    best = total_length()
    for lineage in lineages:
        top = lineage.up
        start = top.xpos
        for offset in (shift, -shift):
            top.xpos = start + offset
            candidate = total_length()
            if candidate < best:
                best = candidate
                break
            top.xpos = start
469
+
470
def redistribute_leaves(nodes):
    """Space the leaves evenly over [0, 1], keeping their current left-to-right order."""
    ordered = sorted((n for n in nodes if type(n) is Leaf), key=lambda leaf: leaf.xpos)
    slots = np.linspace(0, 1, len(ordered))
    for leaf, slot in zip(ordered, slots):
        leaf.xpos = slot
480
+
481
+ ####################################################
482
+
483
def _bottom_up_barycenter(nodes):
    """
    Place every non-leaf node at the mean ``xpos`` of the nodes directly below it.

    Nodes are processed in increasing height so children are placed before
    their parents. A node with no positioned node below it gets 0.5.
    """
    internal = [n for n in nodes if type(n) is not Leaf]
    internal.sort(key=lambda n: n.height)
    for node in internal:
        xs = [below.xpos for below in _get_neighbors_below(node) if below.xpos is not None]
        node.xpos = sum(xs) / len(xs) if xs else 0.5
496
+
497
def _count_crossings(nodes):
    """Return how many pairs of drawn (non-root) parent lineages of ``nodes`` cross."""
    drawn = [lin for lin in get_parent_lineages(nodes) if lin.up is not None]
    return len(get_all_crossing_pairs(drawn))
502
+
503
def add_node_x_positions(nodes, n_restarts=10):
    """
    Assign ``xpos`` to every node, aiming for few lineage crossings.

    Each of ``n_restarts`` attempts:
      1. orders the leaves by a randomised DFS from ``nodes[-1]`` and appends
         any leaf the DFS did not reach,
      2. spaces those leaves evenly and puts internal nodes at the barycenter
         of the nodes below them,
      3. rescales, re-spaces the leaves and runs ``reduce_crossovers`` 5 times.
    The attempt with the fewest crossings is restored (stopping early once a
    crossing-free layout is found). The layout is then refined with branch-length
    reduction over decreasing shifts. Positions are written in place.
    """
    best_crossings = float('inf')
    best_positions = None

    for _attempt in range(n_restarts):
        leaf_order = _dfs_leaf_order(nodes[-1])
        reached = {id(leaf) for leaf in leaf_order}
        leaf_order.extend(n for n in nodes if type(n) is Leaf and id(n) not in reached)

        denom = max(len(leaf_order) - 1, 1)
        for rank, leaf in enumerate(leaf_order):
            leaf.xpos = rank / denom

        _bottom_up_barycenter(nodes)
        rescale_positions(nodes)
        redistribute_leaves(nodes)
        for _ in range(5):
            reduce_crossovers(nodes)

        crossings = _count_crossings(nodes)
        if crossings < best_crossings:
            best_crossings = crossings
            best_positions = {n.nodeid: n.xpos for n in nodes}
        if best_crossings == 0:
            break

    if best_positions:
        for n in nodes:
            n.xpos = best_positions[n.nodeid]

    for shift in (0.5, 0.4, 0.3, 0.2, 0.1, 0.05, 0.02):
        reduce_total_branch_length(nodes, shift)
        rescale_positions(nodes)
        for _ in range(5):
            reduce_crossovers(nodes)
        redistribute_leaves(nodes)
561
+
562
def get_arg_nodes(n=5, N=10000, r=1e-8, L=5e3, simulation="arg"):
    """
    Simulate an ancestral recombination graph (ARG) backwards in time.

    Parameters
    ----------
    n : number of sampled leaves.
    N : population size; each pair of live lineages coalesces at rate 1 / (2N).
    r : recombination rate per unit of sequence (presumably per site and
        generation); multiplied by ``L`` because positions run over [0, 1].
    L : sequence length.
    simulation : "arg" (full ARG), "smcprime" (only lineages whose intervals
        overlap or are adjacent may coalesce) or "smc" (only overlapping ones).

    Returns
    -------
    List of Leaf / Coalescent / Recombination nodes in creation order: leaves
    first, then events in increasing height, the last node being the root
    coalescence. ``xpos`` is filled in by ``add_node_x_positions``.
    """
    assert simulation in ["arg", "smcprime", "smc"]

    # positions live on [0, 1], so scale the rate by the sequence length
    r = r * L

    nodes = []
    live = []  # lineages whose upper end is not yet decided

    for i in range(n):
        leaf = Leaf(i, height=0)
        nodes.append(leaf)
        lin = Lineage(lineageid=i, down=leaf, intervals=[(0, 1)])
        leaf.parent = lin
        live.append(lin)
    last_node = n - 1  # largest node id used so far
    last_lineage = n - 1  # largest lineage id used so far

    while len(live) > 1:
        shuffle(live)

        k = len(live)
        # k*(k-1)/2 pairs, each coalescing at rate 1/(2N). The former
        # ``len(live) * len(live)-1`` parsed as k*k - 1.
        coal_prob = k * (k - 1) / 2 / (2 * N)
        tot_ancestral = sum(interval_sum(x.intervals) for x in live)
        rec_prob = r * tot_ancestral
        waiting_time = exponential(1 / (coal_prob + rec_prob))
        height = waiting_time + nodes[-1].height

        if random() < coal_prob / (coal_prob + rec_prob):
            # coalescence of the first two live lineages (after shuffling)
            # NOTE(review): the rejection loops below never end if no pair of
            # live lineages satisfies the condition.
            if simulation == "smcprime":
                # intervals must overlap or be adjacent
                while not (interval_any_shared_borders(live[0].intervals, live[1].intervals)
                           or interval_intersect(live[0].intervals, live[1].intervals)):
                    shuffle(live)
            elif simulation == "smc":
                # intervals must overlap
                while not interval_intersect(live[0].intervals, live[1].intervals):
                    shuffle(live)

            lin_a, lin_b = live.pop(0), live.pop(0)
            intervals = interval_union(lin_a.intervals, lin_b.intervals)

            if lin_a.down is lin_b.down:
                # both lineages leave the same recombination node ("diamond"):
                # order the children left/right so the drawing is symmetric
                children = [lin_a.down.left_parent, lin_b.down.right_parent]
            else:
                children = [lin_a, lin_b]
            last_node += 1
            node_c = Coalescent(nodeid=last_node, height=height, children=children)
            nodes.append(node_c)

            lin_a.up = node_c
            lin_b.up = node_c

            last_lineage += 1
            lin_c = Lineage(lineageid=last_lineage, down=node_c, intervals=intervals)
            live.append(lin_c)
            node_c.parent = lin_c
        else:
            # recombination on a lineage picked in proportion to its span
            rec_lin = choices(live, weights=[interval_span(x.intervals) for x in live], k=1)[0]
            live.remove(rec_lin)

            # draw the breakpoint uniformly within the ancestral material, then
            # map that offset back to a position on the full sequence
            total_anc = interval_sum(rec_lin.intervals)
            recomb_point_anc = random() * total_anc
            cum = 0
            for s, e in rec_lin.intervals:
                cum += e - s
                if cum > recomb_point_anc:
                    recomb_point = e - (cum - recomb_point_anc)
                    break

            last_node += 1
            rec_node = Recombination(nodeid=last_node, height=height,
                                     child=rec_lin, recomb_point=recomb_point)
            nodes.append(rec_node)

            # two new lineages, both starting at the recombination node
            intervals_a, intervals_b = interval_split(rec_lin.intervals, recomb_point)
            assert interval_sum(intervals_a) and interval_sum(intervals_b)
            # the split conserves ancestral material up to float rounding
            # (an exact == comparison could fail spuriously)
            assert np.isclose(interval_sum(intervals_a) + interval_sum(intervals_b), total_anc)

            last_lineage += 1
            lin_a = Lineage(lineageid=last_lineage, down=rec_node, intervals=intervals_a)
            last_lineage += 1
            lin_b = Lineage(lineageid=last_lineage, down=rec_node, intervals=intervals_b)
            live.append(lin_a)
            live.append(lin_b)

            rec_node.left_parent = lin_a
            rec_node.right_parent = lin_b
            rec_lin.up = rec_node

    add_node_x_positions(nodes)

    return nodes
683
+
684
+ def _rescale_layout(pos, scale=1):
685
+ # rescale to (0,pscale) in all axes
686
+
687
+ # shift origin to (0,0)
688
+ lim = 0 # max coordinate for all axes
689
+ for i in range(pos.shape[1]):
690
+ pos[:, i] -= pos[:, i].min()
691
+ lim = max(pos[:, i].max(), lim)
692
+ # rescale to (0,scale) in all directions, preserves aspect
693
+ for i in range(pos.shape[1]):
694
+ pos[:, i] *= scale / lim
695
+ return pos
696
+
697
def get_positions(nodes):
    """Return the ``(xpos, height)`` drawing coordinates of each node, in input order."""
    return [(node.xpos, node.height) for node in nodes]
705
+
706
def get_breakpoints(nodes):
    """Return the recombination points of all Recombination nodes, sorted ascending."""
    return sorted(node.recomb_point for node in nodes if type(node) is Recombination)
717
+
718
def get_parent_lineages(nodes, root=True):
    """
    Return the distinct lineages directly above ``nodes``, sorted by ``lineageid``.

    Leaves and coalescent nodes contribute their ``parent``; recombination nodes
    contribute ``left_parent`` and ``right_parent``. With ``root=False`` the
    lineage with the highest id is dropped (in a full ARG, the dangling lineage
    above the root).
    """
    unique = set()
    for node in nodes:
        kind = type(node)
        if kind is Recombination:
            unique.add(node.left_parent)
            unique.add(node.right_parent)
        elif kind is Coalescent or kind is Leaf:
            unique.add(node.parent)
    ordered = sorted(unique, key=lambda lineage: lineage.lineageid)
    return ordered if root else ordered[:-1]
742
+
743
def get_child_lineages(nodes):
    """
    Return the distinct lineages directly below ``nodes``, sorted by ``lineageid``.

    Coalescent nodes contribute all their ``children``, recombination nodes their
    ``child``; leaves contribute nothing.
    """
    found = set()
    for node in nodes:
        if type(node) is Coalescent:
            found.update(node.children)
        elif type(node) is Recombination:
            found.add(node.child)
    return sorted(found, key=lambda lineage: lineage.lineageid)
760
+
761
+ # def traverse_marginal(node, interval):
762
+ # """
763
+ # Recursive function for getting marginal tree/ARG
764
+ # """
765
+ # node = deepcopy(node) # TODO: remove if input is a cloned arg.
766
+ # tree_nodes = set()
767
+ # if type(node) is Leaf:
768
+ # tree_nodes.add(node)
769
+ # if type(node) is Recombination:
770
+ # if interval_intersect([interval], node.child.intervals):
771
+ # tree_nodes.add(node)
772
+ # tree_nodes.update(traverse_marginal(node.child.down, interval))
773
+ # elif type(node) is Coalescent:
774
+ # if node.parent is None or interval_intersect([interval], node.parent.intervals):
775
+ # tree_nodes.add(node)
776
+ # del_child = None
777
+ # for i, child in enumerate(node.children):
778
+ # if interval_intersect([interval], child.intervals):
779
+ # tree_nodes.update(traverse_marginal(child.down, interval))
780
+ # else:
781
+ # del_child = i
782
+ # if del_child is not None:
783
+ # del node.children[del_child]
784
+ # return tree_nodes
785
+
786
def _traverse_marginal(node, interval):
    """
    Recursively collect the nodes of the marginal tree/ARG below ``node``.

    ``interval`` is a single ``[start, end]`` pair. Leaves are always kept.
    A recombination node is kept, and followed down, when its child lineage
    overlaps ``interval``. A coalescent node is kept when it has no parent or
    its parent lineage overlaps; its children whose lineages overlap are
    followed. The last non-overlapping child is deleted from ``node.children``
    **in place**, so callers should pass a copy, as ``traverse_marginal`` does.
    Returns a set of nodes.
    """
    query = [interval]
    collected = set()
    kind = type(node)
    if kind is Leaf:
        collected.add(node)
    elif kind is Recombination:
        if interval_intersect(query, node.child.intervals):
            collected.add(node)
            collected.update(_traverse_marginal(node.child.down, interval))
    elif kind is Coalescent:
        if node.parent is None or interval_intersect(query, node.parent.intervals):
            collected.add(node)
            drop = None
            for idx, lineage in enumerate(node.children):
                if interval_intersect(query, lineage.intervals):
                    collected.update(_traverse_marginal(lineage.down, interval))
                else:
                    drop = idx
            if drop is not None:
                del node.children[drop]
    return collected
809
+
810
def traverse_marginal(node, interval):
    """
    Return the set of nodes of the marginal tree/ARG for ``interval`` below ``node``.

    The walk runs on a deep copy of the graph reachable from ``node`` because
    ``_traverse_marginal`` prunes ``children`` lists in place; the returned
    nodes belong to that copy.
    """
    # TODO: skip the copy if the input is already a cloned ARG.
    working_copy = deepcopy(node)
    return _traverse_marginal(working_copy, interval)
813
+
814
+
815
def remove_dangling_root(tree_nodes):
    """
    Remove, in place, the trailing nodes of ``tree_nodes`` that come after the
    last coalescent node with two children.

    For a height-sorted marginal tree/ARG (as ``marginal_arg`` passes it) this
    strips the path from the root down to the first real coalescence. If no such
    coalescent node exists, every node is removed.
    """
    cut = len(tree_nodes)
    while cut > 0:
        candidate = tree_nodes[cut - 1]
        if type(candidate) is Coalescent and len(candidate.children) == 2:
            break
        cut -= 1
    del tree_nodes[cut:]
827
+
828
def marginal_arg(nodes, interval, strip_dangling_root=True):
    """
    Return the marginal ARG for a sequence interval as a list sorted by height.

    The traversal starts at ``nodes[-1]`` (the root for ARGs built by
    ``get_arg_nodes``) and works on a deep copy, so the returned nodes are copies.
    With ``strip_dangling_root`` the nodes above the last two-child coalescence
    are dropped.
    """
    # TODO: nodes = clone_arg(nodes)
    marg_nodes = sorted(traverse_marginal(nodes[-1], list(interval)),
                        key=lambda node: node.height)
    if strip_dangling_root:
        remove_dangling_root(marg_nodes)
    return marg_nodes
849
+
850
def marginal_trees(nodes, interval=(0, 1), strip_dangling_root=False):
    """
    Return the marginal trees between consecutive recombination breakpoints.

    The sequence [0, 1] is cut at every breakpoint. For each resulting segment
    that overlaps ``interval`` (default: the whole sequence), the marginal ARG
    is computed with ``marginal_arg``.

    Returns ``(tree_list, interval_list)``: the non-empty marginal node lists
    and their matching ``(start, end)`` segments.
    """
    borders = [0] + get_breakpoints(nodes) + [1]
    query_start, query_end = interval[0], interval[1]

    tree_list = []
    interval_list = []
    for segment in zip(borders[:-1], borders[1:]):
        seg_start, seg_end = segment
        # Test overlap first: marginal_arg deep-copies the whole graph, so
        # avoid that work for segments that would be discarded anyway.
        if not (seg_end > query_start and query_end > seg_start):
            continue
        marg_nodes = marginal_arg(nodes, segment, strip_dangling_root=strip_dangling_root)
        if marg_nodes:
            tree_list.append(marg_nodes)
            interval_list.append(segment)
    return tree_list, interval_list
872
+
873
+ # def draw_graph(nodes):
874
+ # """
875
+ # Draws graph using matplotlib
876
+ # """
877
+ # # make a graph and positions list
878
+ # positions = list()
879
+ # arg = nx.Graph()
880
+ # for node in nodes:
881
+ # arg.add_node(node.nodeid)
882
+ # positions.append((node.xpos, node.height))
883
+ # if isinstance(node, Recombination):
884
+ # arg.add_edge(node.child.down.nodeid, node.nodeid)
885
+ # elif isinstance(node, Coalescent):
886
+ # for child in node.children:
887
+ # arg.add_edge(child.down.nodeid, node.nodeid)
888
+
889
+ # positions = np.array(positions)
890
+ # positions = dict(zip(arg.nodes(), positions))
891
+
892
+ # #pos = nx.spring_layout(arg)
893
+ # nx.draw(arg, positions, alpha=0.5, node_size=200, with_labels=True)
894
+ # #with_labels=True,
895
+ # #connectionstyle='arc3, rad = 0.1',
896
+ # #arrowstyle='-')
897
+
898
+ # # G = nx.DiGraph() #or G = nx.MultiDiGraph()
899
+ # # G.add_node('A')
900
+ # # G.add_node('B')
901
+ # # G.add_edge('A', 'B', length = 2)
902
+ # # G.add_edge('B', 'A', length = 3)
903
+ # # pos = nx.spring_layout(G)
904
+ # # nx.draw(G, pos, with_labels=True, connectionstyle='arc3, rad = 0.1')
905
+ # # edge_labels=dict([((u,v,),d['length'])
906
+ # # for u,v,d in G.edges(data=True)])
907
+
908
+ # plt.show()
909
+
910
def rescale_positions(nodes):
    """
    Rescale, in place, ``xpos`` to [0, 1] (minimum -> 0, maximum -> 1) and
    ``height`` to [0, 1] by dividing by the largest height.

    Degenerate layouts no longer raise ZeroDivisionError. If all ``xpos`` are
    equal (e.g. a single node) they all become 0. If the largest height is 0,
    heights are left unchanged.
    """
    max_height = max(n.height for n in nodes)
    max_xpos = max(n.xpos for n in nodes)
    min_xpos = min(n.xpos for n in nodes)
    x_span = (max_xpos - min_xpos) or 1
    height_scale = max_height or 1
    for node in nodes:
        node.xpos = (node.xpos - min_xpos) / x_span
        node.height = node.height / height_scale
920
+
921
def arg2json(nodes):
    """
    Serialise an ARG to an indented JSON string.

    The payload has one list per node class (``Coalescent``, ``Recombination``,
    ``Leaf``) holding each node's ``get_dict()``. A ``Lineage`` list holds the
    dicts of all parent lineages. The payload is passed through
    ``compute_arg_xpos`` before being dumped.
    """
    payload = {
        'Coalescent': [],
        'Recombination': [],
        'Leaf': [],
        'Lineage': [lineage.get_dict() for lineage in get_parent_lineages(nodes)],
    }
    for node in nodes:
        payload[type(node).__name__].append(node.get_dict())

    payload = compute_arg_xpos(payload)
    return json.dumps(payload, indent=4)
933
+
934
def json2arg(json_str):
    """
    Rebuild an ARG from a JSON string in the format written by ``arg2json``.

    Lineages and nodes are first created with integer ids in their reference
    fields; those ids are then replaced by the referenced objects. References
    are resolved through id -> object dicts, so ids need not be contiguous or
    start at 0 (the old list indexing silently assumed id == list position).

    Returns the nodes sorted by ``nodeid``.
    """
    data = json.loads(json_str)

    # lineages still hold node ids in ``down``/``up`` at this point
    lineages = [Lineage(**lin_data) for lin_data in data['Lineage']]
    lineage_by_id = {lin.lineageid: lin for lin in lineages}

    nodes = []
    for node_data in data['Coalescent']:
        node = Coalescent(**node_data)
        node.parent = lineage_by_id[node.parent]
        node.children = [lineage_by_id[i] for i in node.children]
        nodes.append(node)

    for node_data in data['Recombination']:
        node = Recombination(**node_data)
        node.left_parent = lineage_by_id[node.left_parent]
        node.right_parent = lineage_by_id[node.right_parent]
        node.child = lineage_by_id[node.child]
        nodes.append(node)

    for node_data in data['Leaf']:
        node = Leaf(**node_data)
        node.parent = lineage_by_id[node.parent]
        nodes.append(node)

    nodes.sort(key=lambda x: x.nodeid)
    node_by_id = {node.nodeid: node for node in nodes}

    # replace the node ids on the lineages with the node objects
    for lineage in lineages:
        lineage.down = node_by_id[lineage.down]
        if lineage.up is not None:
            lineage.up = node_by_id[lineage.up]

    return nodes
976
+
977
if __name__ == '__main__':

    # simulate an ARG (x positions are assigned by get_arg_nodes)
    nodes = get_arg_nodes()
    print(nodes)
    print(deepcopy(nodes))

    # JSON round trip
    json_str = arg2json(nodes)
    print(json_str)
    retrieved_nodes = json2arg(json_str)
    print(nodes)
    print(retrieved_nodes)
    print(nodes == retrieved_nodes)

    # recombination breakpoints
    breakpoints = get_breakpoints(nodes)
    print(breakpoints)

    sys.exit()

    # Unreachable while the sys.exit() above is in place; kept as a usage example.
    # marginal_trees needs a query interval, and breakpoints[1] needs at least
    # two breakpoints, so both are now handled explicitly.
    trees = marginal_trees(nodes, (0, 1))
    print(trees)

    # marginal ARG for some consecutive intervals
    if len(breakpoints) > 1:
        marg_arg = marginal_arg(nodes, [0, breakpoints[1]])
1019
+