arg-dashboard 0.1.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,934 @@
1
+ import logging
2
+ logging.basicConfig(level=logging.INFO)
3
+
4
+ from ast import Import
5
+ import plotly.graph_objects as go
6
+
7
+ import json
8
+ from textwrap import dedent as d
9
+
10
+ import dash
11
+ from dash import dcc
12
+ from dash import html
13
+ from dash.dependencies import Input, Output, State
14
+
15
+ import dash_bootstrap_components as dbc
16
+
17
+ import math
18
+ import sys
19
+
20
+ import logging
21
+
22
+ #from app import app
23
+
24
+ # external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']
25
+ # external_stylesheets = [dbc.themes.GRID]
26
# Dash application styled with the Bootstrap theme used by the dbc components.
external_stylesheets = [dbc.themes.BOOTSTRAP]

app = dash.Dash(__name__, external_stylesheets=external_stylesheets)
# Don't raise when a callback refers to a component id missing from the initial layout.
app.config.suppress_callback_exceptions = True

# The underlying Flask server (presumably exported for WSGI hosting).
server = app.server
33
+
34
+ import json
35
+
36
+ ###########################################
37
+
38
+ import pandas as pd
39
+ import networkx as nx
40
+
41
+ from .arg import (
42
+ Coalescent, Recombination, Leaf,
43
+ interval_sum, interval_diff, interval_intersect,
44
+ get_breakpoints, get_child_lineages,
45
+ arg2json, json2arg,
46
+ get_arg_nodes,
47
+ rescale_positions,
48
+ marginal_arg, traverse_marginal, marginal_trees
49
+ )
50
+ # from arg import Coalescent, Recombination, Leaf, interval_sum, interval_diff, interval_intersect, get_breakpoints, get_child_lineages, rescale_positions, marginal_arg, traverse_marginal, marginal_trees
51
+
52
+ import plotly.colors
53
+
54
def get_continuous_color(colorscale, intermed):
    """
    Return the color found at position ``intermed`` of a plotly continuous colorscale.

    Plotly continuous colorscales assign colors to the range [0, 1]. This function
    linearly interpolates between the two colorscale stops that bracket ``intermed``.

    Plotly doesn't make the colorscales directly accessible in a common format.
    Some are ready to use:

        colorscale = plotly.colors.PLOTLY_SCALES["Greens"]

    Others are just swatches that need to be constructed into a colorscale:

        viridis_colors, scale = plotly.colors.convert_colors_to_same_type(plotly.colors.sequential.Viridis)
        colorscale = plotly.colors.make_colorscale(viridis_colors, scale=scale)

    :param colorscale: A plotly continuous colorscale defined with RGB string colors,
        i.e. a sequence of ``(cutoff, color)`` pairs with ascending cutoffs.
    :param intermed: value in the range [0, 1]. Values outside the span covered by the
        colorscale's cutoffs are clamped to the nearest end color.
    :return: color in rgb string format
    :rtype: str
    :raises ValueError: if ``colorscale`` is empty.
    """
    if len(colorscale) < 1:
        raise ValueError("colorscale must have at least one color")

    if intermed <= 0 or len(colorscale) == 1:
        return colorscale[0][1]
    if intermed >= 1:
        return colorscale[-1][1]

    # Clamp when the stops don't span all of [0, 1]. Previously the search below
    # then never found a lower (or upper) stop and failed with UnboundLocalError.
    first_cutoff, first_color = colorscale[0]
    if intermed <= first_cutoff:
        return first_color
    last_cutoff, last_color = colorscale[-1]
    if intermed > last_cutoff:
        return last_color

    # Find the adjacent stops with low_cutoff < intermed <= high_cutoff; the guards
    # above guarantee such a pair exists for an ascending colorscale.
    for (low_cutoff, low_color), (high_cutoff, high_color) in zip(colorscale, colorscale[1:]):
        if intermed <= high_cutoff:
            return plotly.colors.find_intermediate_color(
                lowcolor=low_color, highcolor=high_color,
                intermed=((intermed - low_cutoff) / (high_cutoff - low_cutoff)),
                colortype="rgb")

    # Only reachable for a colorscale whose cutoffs are not in ascending order.
    return last_color
94
+
95
def _dropdown_col(label, width=2, **dropdown_kwargs):
    """Bootstrap column holding a bold caption above a dcc.Dropdown built from ``dropdown_kwargs``."""
    return dbc.Col([html.B(label), dcc.Dropdown(**dropdown_kwargs)], width=width)


def _graph_panel_row(caption, graph_id, row_style):
    """Row holding one bordered panel: a dedented markdown caption above an empty dcc.Graph."""
    return dbc.Row(
        [
            dbc.Container(
                [
                    dcc.Markdown(d(caption)),
                    dcc.Graph(id=graph_id, figure={'layout': {}}, style={'height': '85%'}),
                ],
                className='pretty_container',
            ),
        ],
        style=row_style,
    )


def _slider_panel(caption, slider):
    """Container wrapping one bordered panel: a dedented markdown caption above ``slider``."""
    return dbc.Container(
        [
            dbc.Container(
                [
                    dbc.Container([dcc.Markdown(d(caption))]),
                    dbc.Container([slider]),
                ],
                className='pretty_container',
            ),
        ],
    )


layout = html.Div(
    [
        # Hidden div that stores the intermediate value: the JSON-encoded ARG
        # written by the `new_data` callback and read by the other callbacks.
        html.Div(id='intermediate-value', style={'display': 'none'}),

        # Top row: simulation controls + ARG figure on the left (8/12), and the
        # hover-driven marginal-tree / ancestral-sequence panels on the right (4/12).
        dbc.Row(
            [
                dbc.Col(
                    [
                        dbc.Container(
                            [
                                dbc.Container(
                                    [
                                        dbc.Row(
                                            [
                                                _dropdown_col(
                                                    "Simulation:",
                                                    id='sim-dropdown',
                                                    options=[
                                                        {'label': "ARG", 'value': 'arg'},
                                                        {'label': "SMC'", 'value': 'smcprime'},
                                                        {'label': "SMC", 'value': 'smc'},
                                                    ],
                                                    value='arg', searchable=False, clearable=False,
                                                ),
                                                _dropdown_col(
                                                    "Nr samples:",
                                                    id='samples-dropdown',
                                                    options=[
                                                        {'label': "3", 'value': 3},
                                                        {'label': "4", 'value': 4},
                                                        {'label': "5", 'value': 5},
                                                    ],
                                                    value=5, searchable=False, clearable=False,
                                                    style={'font-size': "0.85rem"},
                                                ),
                                                _dropdown_col(
                                                    "Length:",
                                                    id='seqlen-dropdown',
                                                    options=[
                                                        {'label': "1kb", 'value': 1e+3},
                                                        {'label': "2kb", 'value': 2e+3},
                                                        {'label': "4kb", 'value': 4e+3},
                                                    ],
                                                    value=2e+3, searchable=False, clearable=False,
                                                    style={'font-size': "0.85rem"},
                                                ),
                                                dbc.Col(
                                                    [
                                                        dbc.Button(
                                                            'New', id='new-arg-button',
                                                            color="primary",
                                                            style={'height': 35, 'font-size': "0.85rem"},
                                                            className="mr-1",
                                                        ),
                                                    ],
                                                    width=3,
                                                ),
                                                dbc.Col([html.Div(id='arg-header')], width=3),
                                            ],
                                            justify="between", align="end", style={'padding': 3},
                                        ),
                                        dcc.Graph(
                                            id='arg-figure',
                                            clear_on_unhover=True,
                                            figure={'layout': {}},
                                            style={'height': '85%'},
                                        ),
                                    ],
                                    className='pretty_container', fluid=True,
                                    style={'flex-grow': 1, 'height': '100%'},
                                ),
                            ],
                            style={'flex-grow': 1, 'height': '100%'},
                        ),
                    ],
                    width=8,
                    style={'display': 'flex', 'flex-direction': 'column', 'overflow': 'hidden'},
                ),
                dbc.Col(
                    [
                        _graph_panel_row(
                            """
                            **Marginal tree(s):** Hover over an ARG node.
                            """,
                            'marginal-tree',
                            {'padding-bottom': 20, 'flex-grow': 0.5, 'height': '50%'},
                        ),
                        _graph_panel_row(
                            """
                            **Ancestral sequences:** Hover over an ARG node.
                            """,
                            'ancestral-sequence',
                            {'flex-grow': 0.5, 'height': '50%'},
                        ),
                    ],
                    width=4,
                    style={'display': 'flex', 'height': '100%',
                           'flex-direction': 'column',
                           'overflow': 'hidden',
                           'padding-right': 25,
                           'padding-left': 20},
                ),
            ],
            className="g-0", style={'height': 650, 'padding-bottom': 20, 'padding-top': 20},
        ),

        # Bottom row: event slider (left) and sequence range slider (right).
        dbc.Row(
            [
                dbc.Col(
                    [
                        _slider_panel(
                            """
                            **Coalesce and recombination events:**
                            Slide to see progression of events.
                            """,
                            dcc.Slider(
                                id='event-slider',
                                min=0, max=40, value=0,
                                marks={str(i): str(i) for i in range(0, 40)},
                                step=None,
                            ),
                        ),
                    ],
                    width=6,
                ),
                dbc.Col(
                    [
                        _slider_panel(
                            """
                            **Recombination points:**
                            Slide to see graph for only part of the sequence.
                            """,
                            dcc.RangeSlider(
                                id='seq-slider',
                                min=0,
                                max=1000,
                                value=[0, 1000],
                                marks={0: '0', 1000: '1'},
                                pushable=30,
                            ),
                        ),
                    ],
                    width=6, align='start',
                ),
            ],
            className="g-0",
        ),
    ],
)
347
+
348
+
349
def get_bezier_points(x1, y1, x2, y2, relative, absolute_limits=None):
    """Return two quadratic-Bezier control points mirrored across segment (x1, y1)-(x2, y2).

    Both points lie on the perpendicular through the segment's midpoint, one on
    each side, at the same distance ("offset") from the midpoint.

    :param relative: without ``absolute_limits`` the offset is ``relative`` times the
        segment length; with ``absolute_limits`` the offset is ``relative`` itself
        (not scaled by the length), clamped into ``absolute_limits``.
    :param absolute_limits: optional ``(lower, upper)`` bounds for the offset.
    :return: ``(b11, b12, b21, b22)``: first control point ``(b11, b12)`` and second
        ``(b21, b22)``. A zero-length segment yields the midpoint twice.
    """
    mid_x = x1 + (x2 - x1) / 2
    mid_y = y1 + (y2 - y1) / 2

    dx, dy = x2 - x1, y2 - y1
    length = math.sqrt(dx**2 + dy**2)

    # NOTE(review): when absolute_limits is given the segment length is ignored,
    # so `relative` acts as an absolute offset -- confirm this is intended.
    if absolute_limits is None:
        offset = length * relative
    else:
        lower, upper = absolute_limits[0], absolute_limits[1]
        offset = min(max(relative, lower), upper)

    if length == 0:
        return mid_x, mid_y, mid_x, mid_y

    # Unit normal to the segment, scaled by the offset.
    unit_x, unit_y = -dy / length, dx / length
    shift_x, shift_y = unit_x * offset, unit_y * offset
    return mid_x + shift_x, mid_y + shift_y, mid_x - shift_x, mid_y - shift_y
379
+
380
+
381
def arg_figure_data(nodes):
    """Build the plotly figure dict for the ARG panel.

    Edges go into a single line trace, except when a recombination node's two
    parent lineages are exactly the children of one coalescent node ("diamond").
    Those two edges are drawn as mirrored curved layout shapes so they don't
    collapse onto one straight line. Nodes go into a marker trace colored by
    ``interval_sum`` of their ancestral intervals (shown as "Fraction ancestral").

    :param nodes: ARG nodes (``Leaf``/``Coalescent``/``Recombination``) with ``xpos``
        and ``height`` attributes. The marker trace keeps this order, so a hovered
        point's ``pointIndex`` indexes into ``nodes``.
    :return: dict with ``data`` ([edge trace, node trace]) and ``layout``.
    """
    edge_x = []
    edge_y = []
    diamond_shapes = []

    for lineage in get_child_lineages(nodes):
        down, up = lineage.down, lineage.up
        is_diamond = (
            type(down) is Recombination
            and type(up) is Coalescent
            and set(up.children) == {down.right_parent, down.left_parent}
        )

        if is_diamond:
            x1, y1 = down.xpos, down.height
            x2, y2 = up.xpos, up.height
            b11, b12, b21, b22 = get_bezier_points(x1, y1, x2, y2,
                                                   relative=0.2,
                                                   absolute_limits=(0.005, 0.05))
            for ctrl_x, ctrl_y in ((b11, b12), (b21, b22)):
                diamond_shapes.append(dict(
                    type="path",
                    path=f"M {x1},{y1} Q {ctrl_x},{ctrl_y} {x2},{y2}",
                    layer='below',
                    line={'width': 2, 'color': 'gray'},
                ))
        else:
            # start, end, then None to break the line before the next edge
            edge_x.extend([down.xpos, up.xpos, None])
            edge_y.extend([down.height, up.height, None])

    traces = [dict(
        x=edge_x,
        y=edge_y,
        mode='lines',
        opacity=1,
        hoverinfo='skip',
        line={'color': 'grey'},
        name='',
    )]

    node_x = []
    node_y = []
    node_text = []
    node_color = []
    for node in nodes:
        node_x.append(node.xpos)
        node_y.append(node.height)
        prop_ancestral = 1
        if type(node) is Coalescent:
            prop_ancestral = interval_sum(node.parent.intervals)
        elif type(node) is Recombination:
            prop_ancestral = interval_sum(node.child.intervals)
        node_text.append(f"Fraction ancestral: {round(prop_ancestral, 2)}<br>Event: {type(node).__name__}")
        node_color.append(prop_ancestral)

    traces.append(dict(
        x=node_x,
        y=node_y,
        text=node_text,
        mode='markers',
        opacity=1,
        hoverinfo='text',
        marker={
            'size': 10,
            'color': node_color,
            'cmin': 0,
            'cmax': 1,
            'line': {'width': 0.7, 'color': 'white'},
            'colorscale': 'Rainbow',
            'colorbar': {
                # `title.side` replaces the deprecated `titleside` attribute, which
                # plotly.js v3 no longer supports.
                'title': {'text': 'Fraction<br>ancestral<br>sequence', 'side': 'top'},
                'thickness': 15,
                'len': 0.5,
                'tickvals': [0, 0.5, 1],
                'ticks': 'outside',
            },
        },
        name='',
    ))

    # `range_color` was dropped from the layout: it is a plotly-express argument,
    # not a layout attribute; the marker's cmin/cmax already fix the color range.
    return dict(data=traces,
                layout=dict(xaxis=dict(fixedrange=True,
                                       range=[-0.1, 1.1],
                                       showgrid=False, showline=False,
                                       zeroline=False, showticklabels=False),
                            yaxis=dict(fixedrange=True,
                                       range=[-0.1, 1.1],
                                       showgrid=False, showline=False,
                                       zeroline=False, showticklabels=False),
                            hovermode='closest',
                            margin={'l': 50, 'b': 20, 't': 20, 'r': 20},
                            transition={'duration': 0},
                            showlegend=False,
                            shapes=diamond_shapes))
516
+
517
+
518
def tree_figure_data(node_lists):
    """Build the plotly figure dict for one or more marginal trees.

    :param node_lists: list of trees, each a list of nodes with ``xpos`` and
        ``height`` attributes. All trees share one edge trace and one marker trace;
        markers are colored by tree index (``i / len(node_lists)``).
    :return: dict with ``data`` ([edge trace, node trace]) and ``layout``. The x axis
        spans up to the right-most node; with no nodes it falls back to [0, 1].
    """
    edge_x = []
    edge_y = []
    node_x = []
    node_y = []
    node_color = []
    nr_trees = len(node_lists)

    for i, nodes in enumerate(node_lists):
        for lineage in get_child_lineages(nodes):
            # start, end, then None to break the line before the next edge
            edge_x.extend([lineage.down.xpos, lineage.up.xpos, None])
            edge_y.extend([lineage.down.height, lineage.up.height, None])

        for node in nodes:
            node_x.append(node.xpos)
            node_y.append(node.height)
            node_color.append(i / nr_trees)

    # The former `max_x = -1` sentinel gave an inverted axis range [-0.02, -0.98]
    # for an empty figure; use the unit range instead.
    max_x = max(node_x) if node_x else 1

    traces = [
        dict(
            x=edge_x,
            y=edge_y,
            mode='lines',
            opacity=1,
            hoverinfo='skip',
            line={'color': 'grey'},
            name='',
        ),
        dict(
            x=node_x,
            y=node_y,
            mode='markers',
            opacity=1,
            hoverinfo='text',
            marker={
                'size': 7,
                'color': node_color,
                'cmin': 0,
                'cmax': 1,
                'colorscale': 'Rainbow',
                'line': {'width': 0.3, 'color': 'white'},
            },
            name='',
        ),
    ]

    # `range_color` is not a plotly layout attribute (it is a plotly-express
    # argument), so it is not emitted; cmin/cmax fix the color range.
    return dict(data=traces,
                layout=dict(xaxis=dict(fixedrange=True,
                                       range=[-0.02, max_x + 0.02],
                                       showgrid=False, showline=False,
                                       zeroline=False, showticklabels=False),
                            yaxis=dict(fixedrange=True,
                                       range=[-0.02, 1.02],
                                       showgrid=False, showline=False,
                                       zeroline=False, showticklabels=False),
                            hovermode='closest',
                            margin={'l': 7, 'b': 10, 't': 10, 'r': 4},
                            transition={'duration': 0},
                            showlegend=False))
603
+
604
+
605
@app.callback(
    Output('arg-header', 'children'),
    [Input('new-arg-button', 'n_clicks')])
def update_header(n_clicks):
    """Render the "Simulation #k" header, where k is the number of "New" clicks plus one."""
    sim_number = 1 if n_clicks is None else n_clicks + 1
    header = """
    **Simulation #{}:**
    """.format(sim_number)
    return dcc.Markdown(d(header))
618
+
619
@app.callback(Output('intermediate-value', 'children'),
              [Input('new-arg-button', 'n_clicks'),
               Input('sim-dropdown', 'value'),
               Input('samples-dropdown', 'value'),
               Input('seqlen-dropdown', 'value')])
def new_data(n_clicks, sim, samples, length):
    """Generate ARG nodes with the current settings and store them JSON-encoded in the hidden div.

    ``n_clicks`` is not used; the "New" button is an Input only so that clicking
    it triggers this callback again.
    """
    return arg2json(get_arg_nodes(L=length, n=samples, simulation=sim))
630
+
631
@app.callback(
    [Output(component_id='event-slider', component_property='min'),
     Output(component_id='event-slider', component_property='max'),
     Output(component_id='event-slider', component_property='step'),
     Output(component_id='event-slider', component_property='value')],
    [Input('intermediate-value', 'children')])
def update_event_slider(jsonified_data):
    """Make the event slider span 0..(number of non-leaf nodes), step 1, selecting the maximum."""
    nodes = json2arg(jsonified_data) if jsonified_data else []
    nr_events = sum(1 for node in nodes if type(node) is not Leaf)
    return 0, nr_events, 1, nr_events
646
+
647
@app.callback(
    [Output(component_id='seq-slider', component_property='min'),
     Output(component_id='seq-slider', component_property='max'),
     Output(component_id='seq-slider', component_property='value'),
     Output(component_id='seq-slider', component_property='marks')],
    [Input('intermediate-value', 'children')])
def update_seq_slider(jsonified_data):
    """Reset the sequence range slider to 0-1000 and mark each breakpoint.

    Breakpoints are multiplied by 1000 to match the slider's range and labelled
    1, 2, ... in the order returned by ``get_breakpoints``. With no breakpoints,
    marks are ``None``.
    """
    nodes = json2arg(jsonified_data) if jsonified_data else []
    marks = {
        pos * 1000: str(number)
        for number, pos in enumerate(get_breakpoints(nodes), start=1)
    }
    return 0, 1000, [0, 1000], marks or None
663
+
664
+
665
@app.callback(
    Output('arg-figure', 'figure'),
    [Input('intermediate-value', 'children'),
     Input('event-slider', 'value'),
     Input('seq-slider', 'value')])
def update_arg_figure(jsonified_data, event, interval):
    """Draw the marginal ARG for the selected sequence range, truncated at the selected event.

    The first ``nr_leaves + event`` nodes returned by ``marginal_arg`` are drawn
    (presumably all leaves followed by the first ``event`` events -- TODO confirm
    marginal_arg's ordering).
    """
    if not jsonified_data:
        return arg_figure_data([])

    nodes = json2arg(jsonified_data)
    # The range slider works in 0-1000; the ARG helpers take fractions of the sequence.
    fractions = [bound / 1000 for bound in interval]
    marginal_nodes = marginal_arg(nodes, fractions)
    nr_leaves = sum(1 for node in nodes if type(node) is Leaf)
    return arg_figure_data(marginal_nodes[:nr_leaves + event])
687
+
688
+
689
@app.callback(
    Output('marginal-tree', 'figure'),
    [Input('intermediate-value', 'children'),
     Input('arg-figure', 'hoverData'),
     Input('seq-slider', 'value')])
def update_marg_tree_figure(jsonified_data, hover, slider_interval):
    """Draw, side by side, the marginal tree(s) below the hovered ARG node.

    The hovered node's ancestral intervals (child lineage for a recombination,
    parent lineage otherwise) are clipped to the sequence-slider range. The trees
    of each remaining interval are collected. Without hover data or stored ARG
    data, the figure is empty.
    """
    trees = []
    if hover and jsonified_data:
        nodes = json2arg(jsonified_data)
        # NOTE(review): pointIndex indexes the node trace of the displayed figure,
        # which is built from marginal_arg(...)[:nr_leaves + event]; this assumes
        # that list lines up with `nodes` -- confirm.
        focus_node = nodes[hover['points'][0]['pointIndex']]
        if type(focus_node) is Recombination:
            ancestral = focus_node.child.intervals
        else:
            ancestral = focus_node.parent.intervals

        # The range slider works in 0-1000; intervals are fractions of the sequence.
        window = [bound / 1000 for bound in slider_interval]
        for interval in interval_intersect([window], ancestral):
            below = sorted(traverse_marginal(focus_node, interval), key=lambda node: node.height)
            interval_trees, _ = marginal_trees(below, interval)
            trees.extend(interval_trees)

    # Squeeze each tree to width 1 / (n + (n - 1) * space) and shift tree i to i / n.
    nr_cols = len(trees)
    space = 0.5
    for col, tree in enumerate(trees):
        for node in tree:
            node.xpos = node.xpos / (nr_cols + (nr_cols - 1) * space) + col / nr_cols

    # TODO: Maybe keep "dangling root" branch here
    return tree_figure_data(trees)
737
+
738
+
739
def _rainbow_colorscale():
    """Plotly's sequential Rainbow swatches as a colorscale of rgb strings."""
    colors, _ = plotly.colors.convert_colors_to_same_type(plotly.colors.sequential.Rainbow)
    return plotly.colors.make_colorscale(colors)


def _marginal_segments(focus_node, intervals):
    """Concatenate the segments ``marginal_trees`` reports below ``focus_node`` for each interval."""
    segments = []
    for interval in intervals:
        below = sorted(traverse_marginal(focus_node, interval), key=lambda node: node.height)
        _, interval_segments = marginal_trees(below, interval)
        segments.extend(interval_segments)
    return segments


def _clipped_segments(segments, window):
    """Intersect ``segments`` with ``window``, as tuples so they can be used as dict keys."""
    return [tuple(seg) for seg in interval_intersect([window], segments)]


def _segment_color_map(*segment_lists):
    """Give each distinct segment a Rainbow color, evenly spaced in sorted segment order."""
    unique_segments = sorted(set().union(*segment_lists))
    colorscale = _rainbow_colorscale()
    return {
        seg: get_continuous_color(colorscale, intermed=rank / len(unique_segments))
        for rank, seg in enumerate(unique_segments)
    }


def _sequence_bar(segments, gray_segments, x, y, color_map):
    """Shapes for one sequence bar of width 2/5 and height 0.1 with its lower-left corner at (x, y).

    Drawn in order: a white background, one colored rect per segment, then a gray
    rect over every part outside the selected sequence range.
    """
    def rect(fillcolor, x0, x1):
        return dict(type='rect', xref='x', yref='y', fillcolor=fillcolor, line={'width': 1},
                    x0=x0, y0=y, x1=x1, y1=y + 0.1)

    shapes = [rect('white', x, x + 2 / 5)]
    shapes += [rect(color_map[seg], x + seg[0] * 2 / 5, x + seg[1] * 2 / 5) for seg in segments]
    shapes += [rect('lightgray', x + seg[0] * 2 / 5, x + seg[1] * 2 / 5) for seg in gray_segments]
    return shapes


def _edge_from_center(x1, y1):
    """Gray line from the event point (0.5, 0.55) to (x1, y1)."""
    return dict(type='line', xref='x', yref='y', line={'width': 2, 'color': 'gray'},
                x0=0.5, y0=0.55, x1=x1, y1=y1)


@app.callback(
    Output('ancestral-sequence', 'figure'),
    [Input('intermediate-value', 'children'),
     Input('arg-figure', 'hoverData'),
     Input('seq-slider', 'value')])
def update_ancestral_seq_figure(jsonified_data, hover, slider_interval):
    """Sketch the ancestral sequence segments around the hovered ARG node.

    * Leaf: a single bar.
    * Recombination: the child bar below the event point and both parent bars above.
    * Otherwise (coalescent): both child bars below and the parent bar above.

    Segments are colored consistently across the bars of one event, and parts
    outside the sequence-slider range are grayed out. Without hover data or stored
    ARG data, the figure is empty.
    """
    shape_list = []

    if hover and jsonified_data:
        nodes = json2arg(jsonified_data)
        focus_node = nodes[hover['points'][0]['pointIndex']]

        # The range slider works in 0-1000; segments are fractions of the sequence.
        window = [bound / 1000 for bound in slider_interval]
        gray_segments = [tuple(seg) for seg in interval_diff([[0, 1]], [window])]

        if type(focus_node) is Leaf:
            color = get_continuous_color(_rainbow_colorscale(), intermed=0)
            shape_list = [dict(type='rect', xref='x', yref='y', fillcolor=color, line={'width': 1},
                               x0=1.5/5, y0=0.25, x1=3.5/5, y1=0.35)]
            shape_list.extend(
                dict(type='rect', xref='x', yref='y', fillcolor='lightgray', line={'width': 1},
                     x0=1.5/5 + seg[0]*2/5, y0=0.25, x1=1.5/5 + seg[1]*2/5, y1=0.35)
                for seg in gray_segments
            )

        elif type(focus_node) is Recombination:
            raw = [_marginal_segments(focus_node, lineage.intervals)
                   for lineage in (focus_node.left_parent, focus_node.right_parent, focus_node.child)]
            left, right, child = (_clipped_segments(segs, window) for segs in raw)
            color_map = _segment_color_map(left, right, child)

            # Lineage lines are listed first so the bars drawn later cover them.
            # NOTE(review): `line_color` is plotly.py's magic-underscore spelling; in
            # this raw dict plotly.js expects `line: {'color': ...}` -- confirm.
            shape_list = [
                _edge_from_center(0.5, 0.1),
                _edge_from_center(0.05, 0.95),
                _edge_from_center(0.95, 0.95),
                dict(type="circle", xref="x", yref="y",
                     fillcolor="black",
                     x0=0.488, y0=0.5285, x1=0.512, y1=0.5715,
                     line_color="black"),
            ]
            shape_list += _sequence_bar(left, gray_segments, x=0, y=0.75, color_map=color_map)
            shape_list += _sequence_bar(right, gray_segments, x=3/5, y=0.75, color_map=color_map)
            shape_list += _sequence_bar(child, gray_segments, x=1.5/5, y=0.25, color_map=color_map)

        else:
            raw = [_marginal_segments(focus_node, lineage.intervals)
                   for lineage in (focus_node.children[0], focus_node.children[1], focus_node.parent)]
            first_child, second_child, parent = (_clipped_segments(segs, window) for segs in raw)
            color_map = _segment_color_map(first_child, second_child, parent)

            shape_list = (
                _sequence_bar(first_child, gray_segments, x=0, y=0.25, color_map=color_map)
                + _sequence_bar(second_child, gray_segments, x=3/5, y=0.25, color_map=color_map)
                + _sequence_bar(parent, gray_segments, x=1.5/5, y=0.75, color_map=color_map)
                + [_edge_from_center(0.5, 0.75),
                   _edge_from_center(1/5, 0.35),
                   _edge_from_center(4/5, 0.35)]
            )

    return dict(
        data=[],
        layout=dict(xaxis=dict(fixedrange=True,
                               range=[-0.01, 1.01],
                               showgrid=False, showline=False,
                               zeroline=False, showticklabels=False),
                    yaxis=dict(fixedrange=True,
                               range=[0, 1],
                               showgrid=False, showline=False,
                               zeroline=False, showticklabels=False),
                    margin={'l': 0, 'b': 0, 't': 20, 'r': 0},
                    transition={'duration': 0},
                    showlegend=False,
                    shapes=shape_list),
    )
914
+
915
+
916
+
917
# Attach the component tree held in `layout` to the Dash app.
app.layout = layout
918
+
919
def run():
    """Command-line entry point: parse options, open a browser tab and serve the dashboard.

    Options: ``--host``/``--ip`` (default 127.0.0.1), ``--port`` (default 8050) and
    ``--debug`` (run the Dash app in debug mode).
    """
    import argparse
    import os
    import webbrowser

    parser = argparse.ArgumentParser(description='Run ARG dashboard')
    parser.add_argument('--host', '--ip', default='127.0.0.1', help='Host address')
    parser.add_argument('--port', type=int, default=8050, help='Port number')
    parser.add_argument('--debug', action='store_true', help='Run in debug mode')
    args = parser.parse_args()

    # A wildcard bind address is not a URL a browser can reliably connect to;
    # point the browser at loopback instead, and bracket IPv6 literals.
    browser_host = '127.0.0.1' if args.host in ('0.0.0.0', '::', '') else args.host
    if ':' in browser_host:
        browser_host = f'[{browser_host}]'

    # In debug mode the Werkzeug reloader re-executes this entry point in a child
    # process marked with WERKZEUG_RUN_MAIN=true; open the tab only once.
    if os.environ.get('WERKZEUG_RUN_MAIN') != 'true':
        webbrowser.open(f'http://{browser_host}:{args.port}/')
    app.run(host=args.host, port=args.port, debug=args.debug)
932
+
933
if __name__ == '__main__':
    # `Dash.run_server` was deprecated in Dash 2 and removed in Dash 3;
    # `Dash.run` is the supported call (and the one `run()` uses).
    app.run(debug=True)