arg-dashboard 0.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arg_dashboard/__init__.py +934 -0
- arg_dashboard/arg.py +1019 -0
- arg_dashboard/arg_layout_force.py +365 -0
- arg_dashboard/arg_layout_json.py +269 -0
- arg_dashboard/arg_layout_mindist.py +412 -0
- arg_dashboard/assets/dashboard.css +139 -0
- arg_dashboard/assets/images/arg.png +0 -0
- arg_dashboard/assets/images/placeholder286x180.png +0 -0
- arg_dashboard/assets/tabs.css +34 -0
- arg_dashboard/index.py +6 -0
- arg_dashboard-0.1.19.dist-info/METADATA +88 -0
- arg_dashboard-0.1.19.dist-info/RECORD +16 -0
- arg_dashboard-0.1.19.dist-info/WHEEL +5 -0
- arg_dashboard-0.1.19.dist-info/entry_points.txt +2 -0
- arg_dashboard-0.1.19.dist-info/licenses/LICENSE +674 -0
- arg_dashboard-0.1.19.dist-info/top_level.txt +1 -0
arg_dashboard/arg.py
ADDED
@@ -0,0 +1,1019 @@

import sys
from math import exp, log
#import networkx.algorithms.non_randomness

import sys
import logging

log = logging.getLogger('arg_dashboard.arg')
# handler = logging.StreamHandler(sys.stderr)
# handler.setFormatter(logging.Formatter('%(name)s - %(levelname)s - %(message)s'))
# log.addHandler(handler)
log.setLevel(logging.INFO)
import json

import random
# print("Setting random seed")
# random.seed(3)
# numpy.random.seed(3)

from random import random, shuffle, choices
from functools import partial
from copy import deepcopy, copy
# import networkx as nx
from functools import reduce
from numpy.random import exponential, sample
import numpy as np
from time import sleep
# import matplotlib.pyplot as plt
from collections import defaultdict

# from .arg_layout_json import compute_arg_xpos
# from .arg_layout_force import compute_arg_xpos
from .arg_layout_mindist import compute_arg_xpos

__all__ = [
    'Lineage',
    'Leaf',
    'get_arg_nodes',
    'arg2json',
    'json2arg',
]

class Lineage(object):
    """
    Lineages are the edges of our graph
    """
    def __init__(self, lineageid=None, down=None, up=None, intervals=None):
        self.lineageid = lineageid
        self.down = down  # node at bottom of edge
        self.up = up  # node at top of edge
        self.intervals = intervals

    def __hash__(self):
        return self.lineageid

    def __eq__(self, other):
        return hasattr(other, 'lineageid') and self.lineageid == other.lineageid

    def __repr__(self):
        return f'{self.lineageid}:{self.intervals}'

    def get_dict(self):
        d = self.__dict__.copy()
        if d['up']:
            d['up'] = d['up'].nodeid
        d['down'] = d['down'].nodeid
        return d

    def toJSON(self):
        return json.dumps(self, default=self.get_dict, sort_keys=True, indent=4)

    def __deepcopy__(self, memo):
        if self in memo:
            return memo.get(self)
        dup = type(self)(
            lineageid=deepcopy(self.lineageid, memo),
            # down=deepcopy(self.down, memo),
            # up=deepcopy(self.up, memo),
            intervals=deepcopy(self.intervals, memo)
        )
        memo[self] = dup
        dup.down = deepcopy(self.down, memo)
        dup.up = deepcopy(self.up, memo)
        return dup


class Node():
    """
    Leaf, Coalescent and Recombination are the nodes of the graph
    """
    def __hash__(self):
        return self.nodeid

    def __eq__(self, other):
        return hasattr(other, 'nodeid') and self.nodeid == other.nodeid

    def __repr__(self):
        return f'{self.nodeid}'


class Leaf(Node):

    def __init__(self, nodeid=None, height=None, parent=None, intervals=[(0, 1)], xpos=None):
        self.nodeid = nodeid
        self.height = height
        self.intervals = intervals
        self.parent = parent
        self.xpos = xpos

    def get_dict(self):
        d = self.__dict__.copy()
        d['parent'] = d['parent'].lineageid
        return d

    def toJSON(self):
        return json.dumps(self, default=self.get_dict, sort_keys=True, indent=4)

    def __deepcopy__(self, memo):
        if self in memo:
            return memo.get(self)
        dup = type(self)(
            nodeid=deepcopy(self.nodeid, memo),
            height=deepcopy(self.height, memo),
            # parent=deepcopy(self.parent, memo),
            intervals=deepcopy(self.intervals, memo),
            xpos=deepcopy(self.xpos, memo)
        )
        memo[self] = dup
        dup.parent = deepcopy(self.parent, memo)
        return dup


class Coalescent(Node):

    def __init__(self, nodeid=None, height=None, children=None, parent=None, xpos=None):
        self.nodeid = nodeid
        self.height = height
        self.children = children
        self.parent = parent
        self.xpos = xpos

    def get_dict(self):
        d = self.__dict__.copy()
        d['children'] = [c.lineageid for c in d['children']]
        d['parent'] = d['parent'].lineageid
        return d

    def toJSON(self):
        return json.dumps(self, default=self.get_dict, sort_keys=True, indent=4)

    def __deepcopy__(self, memo):
        if self in memo:
            return memo.get(self)
        dup = type(self)(
            nodeid=deepcopy(self.nodeid, memo),
            height=deepcopy(self.height, memo),
            # children=deepcopy(self.children, memo),
            # parent=deepcopy(self.parent, memo),
            xpos=deepcopy(self.xpos, memo)
        )
        memo[self] = dup
        dup.children = [deepcopy(c, memo) for c in self.children]
        dup.parent = deepcopy(self.parent, memo)
        return dup


class Recombination(Node):

    def __init__(self, nodeid=None, height=None, child=None, recomb_point=None, left_parent=None, right_parent=None, xpos=None):
        self.nodeid = nodeid
        self.height = height
        self.child = child
        self.left_parent = left_parent
        self.right_parent = right_parent
        self.recomb_point = recomb_point
        self.xpos = xpos

    def get_dict(self):
        d = self.__dict__.copy()
        d['left_parent'] = d['left_parent'].lineageid
        d['right_parent'] = d['right_parent'].lineageid
        d['child'] = d['child'].lineageid
        return d

    def toJSON(self):
        return json.dumps(self, default=self.get_dict, sort_keys=True, indent=4)

    def __deepcopy__(self, memo):
        if self in memo:
            return memo.get(self)
        dup = type(self)(
            nodeid=deepcopy(self.nodeid, memo),
            height=deepcopy(self.height, memo),
            # child=deepcopy(self.child, memo),
            recomb_point=deepcopy(self.recomb_point, memo),
            # left_parent=deepcopy(self.left_parent, memo),
            # right_parent=deepcopy(self.right_parent, memo),
            xpos=deepcopy(self.xpos, memo)
        )
        memo[self] = dup
        dup.child = deepcopy(self.child, memo)
        dup.left_parent = deepcopy(self.left_parent, memo)
        dup.right_parent = deepcopy(self.right_parent, memo)
        return dup

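
# --- Editor's note: the helper below is an illustrative sketch added for this
# --- review; it is not part of the released arg.py. It shows how the classes
# --- above reference each other: a Lineage is an edge whose down/up attributes
# --- point at nodes, while nodes point back at their parent/child lineages.
def _example_manual_arg():
    """Wire up a minimal two-leaf ARG by hand (editor's illustration)."""
    leaf_0 = Leaf(nodeid=0, height=0, intervals=[(0, 1)])
    leaf_1 = Leaf(nodeid=1, height=0, intervals=[(0, 1)])
    lin_0 = Lineage(lineageid=0, down=leaf_0, intervals=[(0, 1)])
    lin_1 = Lineage(lineageid=1, down=leaf_1, intervals=[(0, 1)])
    leaf_0.parent = lin_0
    leaf_1.parent = lin_1
    root = Coalescent(nodeid=2, height=1.0, children=[lin_0, lin_1])
    lin_0.up = root
    lin_1.up = root
    root.parent = Lineage(lineageid=2, down=root, intervals=[(0, 1)])  # dangling root lineage
    return [leaf_0, leaf_1, root]
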
def flatten(list_of_tps):

    return reduce(lambda ls, ival: ls + list(ival), list_of_tps, [])


def unflatten(list_of_endpoints):

    return [[list_of_endpoints[i], list_of_endpoints[i + 1]]
            for i in range(0, len(list_of_endpoints) - 1, 2)]


def merge(query, annot, op):

    a_endpoints = flatten(query)
    b_endpoints = flatten(annot)

    assert a_endpoints == sorted(a_endpoints), "intervals must be sorted and non-overlapping"
    assert b_endpoints == sorted(b_endpoints), "intervals must be sorted and non-overlapping"

    sentinel = max(a_endpoints[-1], b_endpoints[-1]) + 1
    a_endpoints += [sentinel]
    b_endpoints += [sentinel]

    a_index = 0
    b_index = 0

    res = []

    scan = min(a_endpoints[0], b_endpoints[0])
    while scan < sentinel:
        in_a = not ((scan < a_endpoints[a_index]) ^ (a_index % 2))
        in_b = not ((scan < b_endpoints[b_index]) ^ (b_index % 2))
        in_res = op(in_a, in_b)

        if in_res ^ (len(res) % 2):
            res += [scan]
        if scan == a_endpoints[a_index]:
            a_index += 1
        if scan == b_endpoints[b_index]:
            b_index += 1
        scan = min(a_endpoints[a_index], b_endpoints[b_index])

    return unflatten(res)


def interval_diff(a, b):
    if not (a and b):
        return a and a or b
    return merge(a, b, lambda in_a, in_b: in_a and not in_b)


def interval_union(a, b):
    if not (a and b):
        return []
    return merge(a, b, lambda in_a, in_b: in_a or in_b)


def interval_intersect(a, b):
    if not (a and b):
        return []
    return merge(a, b, lambda in_a, in_b: in_a and in_b)


def interval_sum(intervals):
    return sum(e - s for (s, e) in intervals)


def interval_span(intervals):
    starts, ends = zip(*intervals)
    return max(ends) - min(starts)


def interval_split(intervals, pos):
    left, right = list(), list()
    for s, e in intervals:
        if pos >= e:
            left.append((s, e))
        elif pos >= s and pos < e:
            left.append((s, pos))
            right.append((pos, e))
        else:
            right.append((s, e))
    return left, right


def interval_any_shared_borders(a, b):
    log.info('shared borders')
    flat = flatten(a) + flatten(b)
    return len(flat) > len(set(flat))

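
# --- Editor's note: the helper below is an illustrative sketch added for this
# --- review; it is not part of the released arg.py. Intervals are lists of
# --- (start, end) pairs on the 0-1 sequence, and merge() sweeps the sorted
# --- endpoints so the interval_* wrappers behave like set operations.
def _example_interval_algebra():
    """Exercise the interval helpers above (editor's illustration)."""
    a = [(0.0, 0.25), (0.5, 1.0)]
    b = [(0.125, 0.75)]
    assert interval_intersect(a, b) == [[0.125, 0.25], [0.5, 0.75]]
    assert interval_union(a, b) == [[0.0, 1.0]]
    assert interval_diff(a, b) == [[0.0, 0.125], [0.75, 1.0]]
    assert interval_sum(a) == 0.75 and interval_span(a) == 1.0
    left, right = interval_split(a, 0.4)
    assert left == [(0.0, 0.25)] and right == [(0.5, 1.0)]
    return True
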
def _get_neighbors_below(node):
    """Get nodes connected below this node via child lineages."""
    if type(node) is Coalescent:
        return [c.down for c in node.children]
    if type(node) is Recombination:
        return [node.child.down]
    return []


def _get_neighbors_above(node):
    """Get nodes connected above this node via parent lineages."""
    result = []
    if type(node) is Coalescent:
        if node.parent and node.parent.up:
            result.append(node.parent.up)
    elif type(node) is Recombination:
        if node.left_parent and node.left_parent.up:
            result.append(node.left_parent.up)
        if node.right_parent and node.right_parent.up:
            result.append(node.right_parent.up)
    elif type(node) is Leaf:
        if node.parent and node.parent.up:
            result.append(node.parent.up)
    return result


def _dfs_leaf_order(node, visited=None):
    """DFS from root to get a leaf ordering that groups related leaves.
    Randomizes child visit order at coalescent nodes for multi-start."""
    if visited is None:
        visited = set()
    if node in visited:
        return []
    visited.add(node)
    if type(node) is Leaf:
        return [node]
    leaves = []
    if type(node) is Coalescent:
        children = list(node.children)
        shuffle(children)
        for child in children:
            leaves.extend(_dfs_leaf_order(child.down, visited))
    elif type(node) is Recombination:
        leaves.extend(_dfs_leaf_order(node.child.down, visited))
    return leaves


####################################################

def segments_crossing(line1, line2):
    xdiff = (line1[0][0] - line1[1][0], line2[0][0] - line2[1][0])
    ydiff = (line1[0][1] - line1[1][1], line2[0][1] - line2[1][1])

    def det(a, b):
        return a[0] * b[1] - a[1] * b[0]

    div = det(xdiff, ydiff)
    if div == 0:
        return False
        # raise Exception('lines do not intersect')

    d = (det(*line1), det(*line2))
    x = det(d, xdiff) / div
    y = det(d, ydiff) / div

    # return x, y

    return \
        min(line1[0][0], line1[1][0]) < x < max(line1[0][0], line1[1][0]) and \
        min(line2[0][0], line2[1][0]) < x < max(line2[0][0], line2[1][0]) and \
        min(line1[0][1], line1[1][1]) < y < max(line1[0][1], line1[1][1]) and \
        min(line2[0][1], line2[1][1]) < y < max(line2[0][1], line2[1][1])


def crossing(a, b):
    return segments_crossing(
        ((a.down.xpos, a.down.height), (a.up.xpos, a.up.height)),
        ((b.down.xpos, b.down.height), (b.up.xpos, b.up.height))
    )


def get_all_crossing_pairs(lineages):
    crossing_pairs = []
    for i in range(len(lineages)):
        for j in range(i + 1, len(lineages)):
            if crossing(lineages[i], lineages[j]):
                crossing_pairs.append([i, j])
    return crossing_pairs

def reduce_crossovers(nodes):
    lineages = get_parent_lineages(nodes)
    lineages = [x for x in lineages if x.up is not None]
    crossing_pairs = get_all_crossing_pairs(lineages)

    for i, j in crossing_pairs:
        # swap downs
        lineages[i].down.xpos, lineages[j].down.xpos = lineages[j].down.xpos, lineages[i].down.xpos
        new_crossing_pairs = get_all_crossing_pairs(lineages)
        if not len(new_crossing_pairs) < len(crossing_pairs):
            # swap downs back
            lineages[i].down.xpos, lineages[j].down.xpos = lineages[j].down.xpos, lineages[i].down.xpos
        else:
            crossing_pairs = new_crossing_pairs

        # swap ups
        lineages[i].up.xpos, lineages[j].up.xpos = lineages[j].up.xpos, lineages[i].up.xpos
        new_crossing_pairs = get_all_crossing_pairs(lineages)
        if not len(new_crossing_pairs) < len(crossing_pairs):
            # swap ups back
            lineages[i].up.xpos, lineages[j].up.xpos = lineages[j].up.xpos, lineages[i].up.xpos
        else:
            crossing_pairs = new_crossing_pairs

        leaf_idx = [i for i, n in enumerate(nodes) if type(n) is Leaf]

        if type(lineages[i].down) is Leaf and type(lineages[j].down) is not Leaf:
            # try to find new position for leaf i

            for k in leaf_idx:
                # swap
                lineages[i].down.xpos, lineages[k].down.xpos = lineages[k].down.xpos, lineages[i].down.xpos

                new_crossing_pairs = get_all_crossing_pairs(lineages)
                if not len(new_crossing_pairs) < len(crossing_pairs):
                    # swap back
                    lineages[i].down.xpos, lineages[k].down.xpos = lineages[k].down.xpos, lineages[i].down.xpos
                else:
                    break

        if type(lineages[i].down) is Leaf and type(lineages[j].down) is not Leaf:
            # try to find new position for leaf j
            for k in leaf_idx:
                # swap
                lineages[j].down.xpos, lineages[k].down.xpos = lineages[k].down.xpos, lineages[j].down.xpos

                new_crossing_pairs = get_all_crossing_pairs(lineages)
                if not len(new_crossing_pairs) < len(crossing_pairs):
                    # swap back
                    lineages[j].down.xpos, lineages[k].down.xpos = lineages[k].down.xpos, lineages[j].down.xpos
                else:
                    break

    # TODO: if cross over involves an internal and an external branch, see if
    # you can move the leaf to a new position that removes the crossover

    # leaf_idx = [i for i, n in enumerate(nodes) if type(n) is Leaf]
    # suffled_leaf_idx = leaf_idx[:]
    # shuffle(suffled_leaf_idx)
    # for i, j in sorted(zip(leaf_idx, suffled_leaf_idx)):
    #     print(i, j)
    #     nodes[i].xpos = nodes[j].xpos

    # new_crossing_pairs = get_all_crossing_pairs(lineages)
    # if not len(new_crossing_pairs) < len(crossing_pairs):
    #     for i, j in sorted(zip(leaf_idx, suffled_leaf_idx), reverse=True):
    #         nodes[i].xpos = nodes[j].xpos


def branch_length(lineage):
    return ((lineage.down.xpos - lineage.up.xpos)**2 + (lineage.down.height - lineage.up.height)**2)**0.5


def reduce_total_branch_length(nodes, shift):
    lineages = get_parent_lineages(nodes)
    lineages = [x for x in lineages if x.up is not None]

    cur_tot = sum(branch_length(l) for l in lineages)
    for i in range(len(lineages)):
        orig_xpos = lineages[i].up.xpos
        lineages[i].up.xpos += shift
        tot = sum(branch_length(l) for l in lineages)
        if tot < cur_tot:
            cur_tot = tot
        else:
            lineages[i].up.xpos = orig_xpos
            lineages[i].up.xpos -= shift
            tot = sum(branch_length(l) for l in lineages)
            if tot < cur_tot:
                cur_tot = tot
            else:
                lineages[i].up.xpos = orig_xpos


def redistribute_leaves(nodes):

    leaves = [n for n in nodes if type(n) is Leaf]
    # min_x = min(n.xpos for n in leaves)
    # max_x = max(n.xpos for n in leaves)
    leaves = sorted(leaves, key=lambda x: x.xpos)
    min_x = 0
    max_x = 1
    for xpos, leaf in zip(np.linspace(min_x, max_x, len(leaves)), leaves):
        leaf.xpos = xpos


####################################################

def _bottom_up_barycenter(nodes):
    """Single bottom-up pass: each node at mean of its children's x."""
    non_leaves = sorted(
        [n for n in nodes if type(n) is not Leaf],
        key=lambda n: n.height
    )
    for node in non_leaves:
        below = _get_neighbors_below(node)
        placed = [n for n in below if n.xpos is not None]
        if placed:
            node.xpos = sum(n.xpos for n in placed) / len(placed)
        else:
            node.xpos = 0.5


def _count_crossings(nodes):
    """Count total edge crossings in current layout."""
    lineages = get_parent_lineages(nodes)
    lineages = [x for x in lineages if x.up is not None]
    return len(get_all_crossing_pairs(lineages))

def add_node_x_positions(nodes, n_restarts=10):
    """
    Position nodes to minimize crossings using:
    1. DFS leaf ordering (groups related leaves)
    2. Bottom-up barycenter (single pass, no funnel)
    3. Rescale + crossing reduction post-processing

    Multi-start: tries n_restarts random DFS orderings, keeps best.
    """
    best_crossings = float('inf')
    best_positions = None

    for _ in range(n_restarts):
        # Step 1: DFS leaf ordering (randomized child order)
        leaf_order = _dfs_leaf_order(nodes[-1])
        seen = set(id(l) for l in leaf_order)
        for n in nodes:
            if type(n) is Leaf and id(n) not in seen:
                leaf_order.append(n)

        # Step 2: Fix leaves evenly spaced
        n_leaves = len(leaf_order)
        for i, leaf in enumerate(leaf_order):
            leaf.xpos = i / max(n_leaves - 1, 1)

        # Step 3: Bottom-up barycenter
        _bottom_up_barycenter(nodes)

        # Step 4: Rescale to [0, 1]
        rescale_positions(nodes)

        # Step 5: Re-even leaves, reduce crossings
        redistribute_leaves(nodes)
        for _ in range(5):
            reduce_crossovers(nodes)

        # Evaluate
        crossings = _count_crossings(nodes)
        if crossings < best_crossings:
            best_crossings = crossings
            best_positions = {n.nodeid: n.xpos for n in nodes}

        if best_crossings == 0:
            break

    # Restore best positions
    if best_positions:
        for n in nodes:
            n.xpos = best_positions[n.nodeid]

    # Final refinement
    shift_list = [0.5, 0.4, 0.3, 0.2, 0.1, 0.05, 0.02]
    for shift in shift_list:
        reduce_total_branch_length(nodes, shift)
        rescale_positions(nodes)
        for _ in range(5):
            reduce_crossovers(nodes)
    redistribute_leaves(nodes)

def get_arg_nodes(n=5, N=10000, r=1e-8, L=5e3, simulation="arg"):
    """
    Simulates an ARG
    """
    # because we use the sequence interval from 0 to 1
    r = r * L

    # print('\nSIM:')

    assert simulation in ["arg", "smcprime", "smc"]

    nodes = list()
    live = list()  # live lineages

    for i in range(n):
        leaf = Leaf(i, height=0)
        nodes.append(leaf)
        lin = Lineage(lineageid=i, down=leaf, intervals=[(0, 1)])
        leaf.parent = lin
        live.append(lin)
        last_node = i  # max node number used so far
        last_lineage = i

    while len(live) > 1:
        shuffle(live)

        coal_prob = len(live) * (len(live) - 1) / 2 / (2 * N)
        tot_ancestral = sum(interval_sum(x.intervals) for x in live)
        rec_prob = r * tot_ancestral
        waiting_time = exponential(1 / (coal_prob + rec_prob))
        height = waiting_time + nodes[-1].height
        if random() < coal_prob / (coal_prob + rec_prob):
            # coalescence

            # if full arg:
            a, b = 0, 1

            if simulation == "smcprime":
                # intervals must overlap or be adjacent:
                while not (interval_any_shared_borders(live[a].intervals, live[b].intervals) or
                           interval_intersect(live[a].intervals, live[b].intervals)):
                    shuffle(live)
            elif simulation == "smc":
                # intervals must overlap:
                while not interval_intersect(live[a].intervals, live[b].intervals):
                    shuffle(live)

            lin_a, lin_b = live.pop(0), live.pop(0)

            intervals = interval_union(lin_a.intervals, lin_b.intervals)

            # new node
            if lin_a.down is lin_b.down:
                # make sure diamond is symmetric
                children = [lin_a.down.left_parent, lin_b.down.right_parent]
            else:
                children = [lin_a, lin_b]
            node_c = Coalescent(nodeid=last_node + 1, height=height,
                                children=children)

            last_node += 1
            nodes.append(node_c)

            # add node to top of coalescing lineages
            lin_a.up = node_c
            lin_b.up = node_c

            # new lineage
            lin_c = Lineage(lineageid=last_lineage + 1, down=node_c, intervals=intervals)  # fixme intervals
            last_lineage += 1
            live.append(lin_c)
            node_c.parent = lin_c
        else:
            # recombination

            # rec_lin = choices(live, weights=[interval_sum(x.intervals) for x in live], k=1)[0]
            rec_lin = choices(live, weights=[interval_span(x.intervals) for x in live], k=1)[0]
            live.remove(rec_lin)

            # total ancestral material
            total_anc = interval_sum(rec_lin.intervals)

            # recombination point in ancestral material
            recomb_point_anc = random() * total_anc

            # recombination point in full sequence
            cum = 0
            for s, e in rec_lin.intervals:
                cum += e - s
                if cum > recomb_point_anc:
                    recomb_point = e - (cum - recomb_point_anc)
                    break

            # recombination node
            rec_node = Recombination(nodeid=last_node + 1, height=height,
                                     child=rec_lin, recomb_point=recomb_point)
            last_node += 1
            nodes.append(rec_node)

            # two new lineages both referring back to recombination node
            intervals_a, intervals_b = interval_split(rec_lin.intervals, recomb_point)
            assert interval_sum(intervals_a) and interval_sum(intervals_b)
            assert sum(e - s for s, e in intervals_a) + sum(e - s for s, e in intervals_b) == sum(e - s for s, e in rec_lin.intervals)
            lin_a = Lineage(lineageid=last_lineage + 1, down=rec_node, intervals=intervals_a)
            last_lineage += 1
            lin_b = Lineage(lineageid=last_lineage + 1, down=rec_node, intervals=intervals_b)
            last_lineage += 1
            live.append(lin_a)
            live.append(lin_b)

            # add parents of node
            rec_node.left_parent = lin_a
            rec_node.right_parent = lin_b

            # add parent for child
            rec_lin.up = rec_node

    add_node_x_positions(nodes)

    return nodes

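
# --- Editor's note: the helper below is an illustrative sketch added for this
# --- review; it is not part of the released arg.py. get_arg_nodes() returns
# --- the nodes sorted by event height (leaves first, grand MRCA last), with
# --- xpos already assigned by add_node_x_positions().
def _example_simulation():
    """Simulate a small ARG and inspect it (editor's illustration)."""
    nodes = get_arg_nodes(n=4, N=10000, r=1e-8, L=5e3, simulation="smcprime")
    n_recombinations = sum(1 for node in nodes if type(node) is Recombination)
    breakpoints = get_breakpoints(nodes)
    assert len(breakpoints) == n_recombinations
    root = nodes[-1]  # oldest node: the grand MRCA
    assert type(root) is Coalescent and root.parent.up is None
    return nodes, breakpoints
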
def _rescale_layout(pos, scale=1):
    # rescale to (0,scale) in all axes

    # shift origin to (0,0)
    lim = 0  # max coordinate for all axes
    for i in range(pos.shape[1]):
        pos[:, i] -= pos[:, i].min()
        lim = max(pos[:, i].max(), lim)
    # rescale to (0,scale) in all directions, preserves aspect
    for i in range(pos.shape[1]):
        pos[:, i] *= scale / lim
    return pos


def get_positions(nodes):
    """
    Gets list of x,y positions for nodes
    """
    positions = list()
    for node in nodes:
        positions.append((node.xpos, node.height))
    return positions


def get_breakpoints(nodes):
    """
    Gets list of recombination break points
    """
    # get list of break points
    breakpoints = list()
    for n in nodes:
        if type(n) is Recombination:
            # if type(n) is Recombination and all(p.up in nodes for p in n.parents):
            breakpoints.append(n.recomb_point)
    return sorted(breakpoints)


def get_parent_lineages(nodes, root=True):
    """
    Get list of parent lineages of nodes ordered by id
    """

    lineages = list()
    for node in nodes:
        if type(node) is Coalescent:
            lineages.append(node.parent)
        if type(node) is Recombination:
            lineages.append(node.left_parent)
            lineages.append(node.right_parent)
        if type(node) is Leaf:
            lineages.append(node.parent)

    # get unique and sort them so lineageid matches index
    lineages = list(set(lineages))
    lineages.sort(key=lambda x: x.lineageid)

    if not root:
        # remove dangling root lineage
        lineages = lineages[:-1]

    return lineages


def get_child_lineages(nodes):
    """
    Get list of child lineages of nodes ordered by id
    """

    lineages = list()
    for node in nodes:
        if type(node) is Coalescent:
            lineages.extend(node.children)
        if type(node) is Recombination:
            lineages.append(node.child)

    # get unique and sort them so lineageid matches index
    lineages = list(set(lineages))
    lineages.sort(key=lambda x: x.lineageid)

    return lineages

# def traverse_marginal(node, interval):
#     """
#     Recursive function for getting marginal tree/ARG
#     """
#     node = deepcopy(node)  # TODO: remove if input is a cloned arg.
#     tree_nodes = set()
#     if type(node) is Leaf:
#         tree_nodes.add(node)
#     if type(node) is Recombination:
#         if interval_intersect([interval], node.child.intervals):
#             tree_nodes.add(node)
#             tree_nodes.update(traverse_marginal(node.child.down, interval))
#     elif type(node) is Coalescent:
#         if node.parent is None or interval_intersect([interval], node.parent.intervals):
#             tree_nodes.add(node)
#         del_child = None
#         for i, child in enumerate(node.children):
#             if interval_intersect([interval], child.intervals):
#                 tree_nodes.update(traverse_marginal(child.down, interval))
#             else:
#                 del_child = i
#         if del_child is not None:
#             del node.children[del_child]
#     return tree_nodes

def _traverse_marginal(node, interval):
    """
    Recursive function for getting marginal tree/ARG
    """
    tree_nodes = set()
    if type(node) is Leaf:
        tree_nodes.add(node)
    if type(node) is Recombination:
        if interval_intersect([interval], node.child.intervals):
            tree_nodes.add(node)
            tree_nodes.update(_traverse_marginal(node.child.down, interval))
    elif type(node) is Coalescent:
        if node.parent is None or interval_intersect([interval], node.parent.intervals):
            tree_nodes.add(node)
        del_child = None
        for i, child in enumerate(node.children):
            if interval_intersect([interval], child.intervals):
                tree_nodes.update(_traverse_marginal(child.down, interval))
            else:
                del_child = i
        if del_child is not None:
            del node.children[del_child]
    return tree_nodes


def traverse_marginal(node, interval):
    clone = deepcopy(node)  # TODO: remove if input is a cloned arg.
    return _traverse_marginal(clone, interval)


def remove_dangling_root(tree_nodes):
    """
    Remove the nodes of a marginal tree or ARG
    from root to first coalescence
    """
    for_del = list()
    for i in range(len(tree_nodes) - 1, -1, -1):
        if type(tree_nodes[i]) is Coalescent and len(tree_nodes[i].children) == 2:
            break
        for_del.append(i)
    for i in for_del:
        del tree_nodes[i]


def marginal_arg(nodes, interval, strip_dangling_root=True):
    """
    Gets the marginal ARG given a sequence interval
    """

    # TODO:
    # nodes = clone_arg(nodes)

    # get nodes for marginal
    marg_nodes = traverse_marginal(nodes[-1], list(interval))

    # set to list
    marg_nodes = list(marg_nodes)

    # sort on height
    marg_nodes.sort(key=lambda x: x.height)
    # prune top path above last coalescence
    if strip_dangling_root:
        remove_dangling_root(marg_nodes)

    return marg_nodes

def marginal_trees(nodes, interval, strip_dangling_root=False):
    """
    Gets list of marginal trees
    """
    tree_list = list()
    breakpoints = get_breakpoints(nodes)

    borders = [0] + breakpoints + [1]

    # for interval in zip(borders[:-1], borders[1:]):
    #     marg_nodes = marginal_arg(nodes, interval)
    #     tree_list.append(marg_nodes)
    # return tree_list

    interval_list = list()
    for interv in zip(borders[:-1], borders[1:]):
        marg_nodes = marginal_arg(nodes, interv, strip_dangling_root=strip_dangling_root)
        # TODO: make this correct:...
        if marg_nodes and interv[1] > interval[0] and interval[1] > interv[0]:
            tree_list.append(marg_nodes)
            interval_list.append(interv)
    return tree_list, interval_list

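
# --- Editor's note: the helper below is an illustrative sketch added for this
# --- review; it is not part of the released arg.py. marginal_arg() deep-copies
# --- the ARG (via traverse_marginal) and keeps the nodes whose lineages overlap
# --- the requested sequence interval; marginal_trees() repeats this for every
# --- segment between recombination breakpoints that overlaps the interval.
def _example_marginal_trees():
    """Extract marginal trees from a simulated ARG (editor's illustration)."""
    nodes = get_arg_nodes(n=4)
    trees, intervals = marginal_trees(nodes, (0, 1))
    # at most one tree per segment between neighbouring breakpoints
    assert len(trees) == len(intervals) <= len(get_breakpoints(nodes)) + 1
    first_tree = marginal_arg(nodes, intervals[0])
    assert all(isinstance(n, (Leaf, Coalescent, Recombination)) for n in first_tree)
    return trees, intervals
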
# def draw_graph(nodes):
#     """
#     Draws graph using matplotlib
#     """
#     # make a graph and positions list
#     positions = list()
#     arg = nx.Graph()
#     for node in nodes:
#         arg.add_node(node.nodeid)
#         positions.append((node.xpos, node.height))
#         if isinstance(node, Recombination):
#             arg.add_edge(node.child.down.nodeid, node.nodeid)
#         elif isinstance(node, Coalescent):
#             for child in node.children:
#                 arg.add_edge(child.down.nodeid, node.nodeid)

#     positions = np.array(positions)
#     positions = dict(zip(arg.nodes(), positions))

#     #pos = nx.spring_layout(arg)
#     nx.draw(arg, positions, alpha=0.5, node_size=200, with_labels=True)
#     #with_labels=True,
#     #connectionstyle='arc3, rad = 0.1',
#     #arrowstyle='-')

#     # G = nx.DiGraph() #or G = nx.MultiDiGraph()
#     # G.add_node('A')
#     # G.add_node('B')
#     # G.add_edge('A', 'B', length = 2)
#     # G.add_edge('B', 'A', length = 3)
#     # pos = nx.spring_layout(G)
#     # nx.draw(G, pos, with_labels=True, connectionstyle='arc3, rad = 0.1')
#     # edge_labels=dict([((u,v,),d['length'])
#     #                   for u,v,d in G.edges(data=True)])

#     plt.show()


def rescale_positions(nodes):
    """
    Rescales xpos and heights to 0,1 range
    """
    max_height = max(n.height for n in nodes)
    max_xpos = max(n.xpos for n in nodes)
    min_xpos = min(n.xpos for n in nodes)
    for node in nodes:
        node.xpos = (node.xpos - min_xpos) / (max_xpos - min_xpos)
        node.height = node.height / max_height


def arg2json(nodes):

    lineages = get_parent_lineages(nodes)

    json_data = dict(Coalescent=[], Recombination=[], Leaf=[],
                     Lineage=[l.get_dict() for l in lineages])
    for node in nodes:
        json_data[node.__class__.__name__].append(node.get_dict())

    json_data = compute_arg_xpos(json_data)

    return json.dumps(json_data, indent=4)


def json2arg(json_str):

    nodes = list()

    data = json.loads(json_str)

    # make lineages (with indexes instead of node refs)
    lineages = [Lineage(**data) for data in data['Lineage']]

    for node_data in data['Coalescent']:
        # make the node
        node = Coalescent(**node_data)
        # populate the parent and children with actual lineages
        node.parent = lineages[node.parent]
        node.children = [lineages[i] for i in node.children]
        nodes.append(node)

    for node_data in data['Recombination']:
        # make the node
        node = Recombination(**node_data)
        # populate the parents and child with actual lineages
        node.left_parent = lineages[node.left_parent]
        node.right_parent = lineages[node.right_parent]
        node.child = lineages[node.child]
        nodes.append(node)

    for node_data in data['Leaf']:
        # make the node
        node = Leaf(**node_data)
        # populate the parent with actual lineages
        node.parent = lineages[node.parent]
        nodes.append(node)

    nodes.sort(key=lambda x: x.nodeid)

    # populate up and down with the nodes
    for lineage in lineages:
        lineage.down = nodes[lineage.down]
        if lineage.up is not None:
            lineage.up = nodes[lineage.up]

    return nodes

if __name__ == '__main__':

    # get arg and add positions
    nodes = get_arg_nodes()

    print(nodes)

    print(deepcopy(nodes))

    json_str = arg2json(nodes)
    print(json_str)
    retrieved_nodes = json2arg(json_str)

    print(nodes)
    print(retrieved_nodes)
    print(nodes == retrieved_nodes)

    # get breakpoints
    breakpoints = get_breakpoints(nodes)
    print(breakpoints)

    sys.exit()

    # get marginal trees
    trees = marginal_trees(nodes)
    print(trees)

    # marginal arg for some consecutive intervals
    marg_arg = marginal_arg(nodes, [0, breakpoints[1]])

    # # draw graphs for testing
    # draw_graph(nodes)
    # draw_graph(marg_arg)

    # draw_graph(nodes)
    # for tree in marginal_trees(nodes):
    #     print([n.xpos for n in tree])
    #     draw_graph(tree)