pyRDDLGym-jax 0.5-py3-none-any.whl → 1.0-py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
Files changed (43)
  1. pyRDDLGym_jax/__init__.py +1 -1
  2. pyRDDLGym_jax/core/compiler.py +463 -592
  3. pyRDDLGym_jax/core/logic.py +784 -544
  4. pyRDDLGym_jax/core/planner.py +329 -463
  5. pyRDDLGym_jax/core/simulator.py +7 -5
  6. pyRDDLGym_jax/core/tuning.py +379 -568
  7. pyRDDLGym_jax/core/visualization.py +1463 -0
  8. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
  9. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +4 -5
  10. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
  11. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
  12. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
  13. pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
  14. pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
  15. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
  16. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +3 -3
  17. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
  18. pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
  19. pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
  20. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
  21. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +3 -3
  22. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
  23. pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
  24. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
  25. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +3 -3
  26. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
  27. pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
  28. pyRDDLGym_jax/examples/configs/default_replan.cfg +3 -3
  29. pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
  30. pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
  31. pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
  32. pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
  33. pyRDDLGym_jax/examples/run_plan.py +4 -1
  34. pyRDDLGym_jax/examples/run_tune.py +40 -27
  35. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/METADATA +161 -104
  36. pyRDDLGym_jax-1.0.dist-info/RECORD +45 -0
  37. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/WHEEL +1 -1
  38. pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
  39. pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
  40. pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
  41. pyRDDLGym_jax-0.5.dist-info/RECORD +0 -44
  42. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/LICENSE +0 -0
  43. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/top_level.txt +0 -0
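The main addition in 1.0 is the new pyRDDLGym_jax/core/visualization.py module (shown in full below), which implements a Dash-based monitoring dashboard, JaxPlannerDashboard. A minimal usage sketch follows, assuming a JaxBackpropPlanner whose training loop yields per-iteration callback dictionaries of the shape consumed by update_experiment; the optimize_generator name and the planner constructor arguments are assumptions, not confirmed by this diff.

    # Hypothetical wiring of the new dashboard into a planner training loop.
    # Only JaxPlannerDashboard, get_planner_info, register_experiment,
    # update_experiment and launch come from the module below; the
    # planner-side generator is an assumption.
    from pyRDDLGym_jax.core.planner import JaxBackpropPlanner
    from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard

    planner = JaxBackpropPlanner(...)   # assumed/elided constructor arguments
    dashboard = JaxPlannerDashboard()
    dashboard.launch(port=1222)         # starts Dash in a daemon thread, opens a browser

    info = JaxPlannerDashboard.get_planner_info(planner)
    exp_id = dashboard.register_experiment(None, info, key=42)

    # assumed: the planner exposes an iterator of per-iteration callback dicts
    for callback in planner.optimize_generator(epochs=1000):
        dashboard.update_experiment(exp_id, callback)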
pyRDDLGym_jax/core/visualization.py (new file)
@@ -0,0 +1,1463 @@
1
+ import ast
2
+ import os
3
+ from datetime import datetime
4
+ import math
5
+ import numpy as np
6
+ import time
7
+ import threading
8
+ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
9
+ import warnings
10
+ import webbrowser
11
+
12
+ # prevent endless console prints
13
+ import logging
14
+ log = logging.getLogger('werkzeug')
15
+ log.setLevel(logging.ERROR)
16
+
17
+ import dash
18
+ from dash.dcc import Interval, Graph, Store
19
+ from dash.dependencies import Input, Output, State, ALL
20
+ from dash.html import Div, B, H4, P, Img, Hr
21
+ import dash_bootstrap_components as dbc
22
+
23
+ import plotly.colors as pc
24
+ import plotly.graph_objs as go
25
+ from plotly.subplots import make_subplots
26
+
27
+ from pyRDDLGym.core.debug.decompiler import RDDLDecompiler
28
+
29
+ if TYPE_CHECKING:
30
+ from pyRDDLGym_jax.core.planner import JaxBackpropPlanner
31
+
32
+ POLICY_DIST_HEIGHT = 400
33
+ POLICY_DIST_PLOTS_PER_ROW = 6
34
+ ACTION_HEATMAP_HEIGHT = 400
35
+ PROGRESS_FOR_NEXT_RETURN_DIST = 2
36
+ PROGRESS_FOR_NEXT_POLICY_DIST = 10
37
+ REWARD_ERROR_DIST_SUBPLOTS = 20
38
+ MODEL_STATE_ERROR_HEIGHT = 300
39
+ POLICY_STATE_VIZ_MAX_HEIGHT = 800
40
+ GP_POSTERIOR_MAX_HEIGHT = 800
41
+
42
+ PLOT_AXES_FONT_SIZE = 11
43
+ EXPERIMENT_ENTRY_FONT_SIZE = 14
44
+
45
+
46
+ class JaxPlannerDashboard:
47
+ '''A dashboard app for monitoring the jax planner progress.'''
48
+
49
+ def __init__(self, theme: str=dbc.themes.CERULEAN) -> None:
50
+
51
+ self.timestamps = {}
52
+ self.duration = {}
53
+ self.seeds = {}
54
+ self.status = {}
55
+ self.warnings = []
56
+ self.progress = {}
57
+ self.checked = {}
58
+ self.rddl = {}
59
+ self.planner_info = {}
60
+
61
+ self.xticks = {}
62
+ self.test_return = {}
63
+ self.train_return = {}
64
+ self.return_dist = {}
65
+ self.return_dist_ticks = {}
66
+ self.return_dist_last_progress = {}
67
+ self.action_output = {}
68
+ self.policy_params = {}
69
+ self.policy_params_ticks = {}
70
+ self.policy_params_last_progress = {}
71
+ self.policy_viz = {}
72
+
73
+ self.relaxed_exprs = {}
74
+ self.relaxed_exprs_values = {}
75
+ self.train_reward_dist = {}
76
+ self.test_reward_dist = {}
77
+ self.train_state_fluents = {}
78
+ self.test_state_fluents = {}
79
+
80
+ self.tuning_gp_heatmaps = None
81
+ self.tuning_gp_targets = None
82
+ self.tuning_gp_predicted = None
83
+ self.tuning_gp_params = None
84
+ self.tuning_gp_update = False
85
+
86
+ # ======================================================================
87
+ # CREATE PAGE LAYOUT
88
+ # ======================================================================
89
+
90
+ def create_experiment_table(active_page, page_size):
91
+ start = (active_page - 1) * page_size
92
+ end = start + page_size
93
+ rows = []
94
+
95
+ # header
96
+ row = dbc.Row([
97
+ dbc.Col([
98
+ dbc.Card(dbc.CardBody(
99
+ B('Display'), style={"padding": "0"}
100
+ ), className="border-0 bg-transparent")
101
+ ], width=1),
102
+ dbc.Col([
103
+ dbc.Card(dbc.CardBody(
104
+ B('Experiment ID'), style={"padding": "0"}
105
+ ), className="border-0 bg-transparent"),
106
+ ], width=2),
107
+ dbc.Col([
108
+ dbc.Card(dbc.CardBody(
109
+ B('Seed'), style={"padding": "0"}
110
+ ), className="border-0 bg-transparent")
111
+ ], width=1),
112
+ dbc.Col([
113
+ dbc.Card(dbc.CardBody(
114
+ B('Timestamp'), style={"padding": "0"}
115
+ ), className="border-0 bg-transparent")
116
+ ], width=2),
117
+ dbc.Col([
118
+ dbc.Card(dbc.CardBody(
119
+ B('Duration'), style={"padding": "0"}
120
+ ), className="border-0 bg-transparent")
121
+ ], width=1),
122
+ dbc.Col([
123
+ dbc.Card(dbc.CardBody(
124
+ B('Best Return'), style={"padding": "0"}
125
+ ), className="border-0 bg-transparent")
126
+ ], width=2),
127
+ dbc.Col([
128
+ dbc.Card(dbc.CardBody(
129
+ B('Status'), style={"padding": "0"}
130
+ ), className="border-0 bg-transparent")
131
+ ], width=2),
132
+ dbc.Col([
133
+ dbc.Card(dbc.CardBody(
134
+ B('Progress'), style={"padding": "0"}
135
+ ), className="border-0 bg-transparent")
136
+ ], width=1)
137
+ ])
138
+ rows.append(row)
139
+
140
+ # experiments
141
+ for (i, id) in enumerate(self.checked):
142
+ if i >= end:
143
+ break
144
+ if i >= start and i < end:
145
+ progress = self.progress[id]
146
+ row = dbc.Row([
147
+ dbc.Col([
148
+ dbc.Card(
149
+ dbc.CardBody([
150
+ dbc.Checkbox(
151
+ id={'type': 'experiment-checkbox', 'index': id},
152
+ value=self.checked[id]
153
+ )],
154
+ style={"padding": "0"}
155
+ ),
156
+ className="border-0 bg-transparent"
157
+ )
158
+ ], width=1),
159
+ dbc.Col([
160
+ dbc.Card(
161
+ dbc.CardBody(id, style={"padding": "0"}),
162
+ className="border-0 bg-transparent"
163
+ )
164
+ ], width=2),
165
+ dbc.Col([
166
+ dbc.Card(
167
+ dbc.CardBody(self.seeds[id], style={"padding": "0"}),
168
+ className="border-0 bg-transparent"
169
+ )
170
+ ], width=1),
171
+ dbc.Col([
172
+ dbc.Card(
173
+ dbc.CardBody(self.timestamps[id], style={"padding": "0"}),
174
+ className="border-0 bg-transparent"
175
+ ),
176
+ ], width=2),
177
+ dbc.Col([
178
+ dbc.Card(
179
+ dbc.CardBody(f'{self.duration[id]:.3f}s', style={"padding": "0"}),
180
+ className="border-0 bg-transparent"
181
+ ),
182
+ ], width=1),
183
+ dbc.Col([
184
+ dbc.Card(
185
+ dbc.CardBody(f'{(self.test_return[id] or [np.nan])[-1]:.3f}',
186
+ style={"padding": "0"}),
187
+ className="border-0 bg-transparent"
188
+ ),
189
+ ], width=2),
190
+ dbc.Col([
191
+ dbc.Card(
192
+ dbc.CardBody(self.status[id], style={"padding": "0"}),
193
+ className="border-0 bg-transparent"
194
+ ),
195
+ ], width=2),
196
+ dbc.Col([
197
+ dbc.Card(
198
+ dbc.CardBody(
199
+ [dbc.Progress(label=f"{progress}%", value=progress)],
200
+ style={"padding": "0"}
201
+ ),
202
+ className="border-0 bg-transparent"
203
+ ),
204
+ ], width=1),
205
+ ])
206
+ rows.append(row)
207
+ return rows
208
+
209
+ app = dash.Dash(__name__, external_stylesheets=[theme])
210
+ app.title = 'JaxPlan Dashboard'
211
+
212
+ app.layout = dbc.Container([
213
+ Store(id='refresh-interval', data=2000),
214
+ Store(id='experiment-num-per-page', data=10),
215
+ Store(id='model-params-dropdown-expr', data=''),
216
+ Store(id='model-errors-state-dropdown-selected', data=''),
217
+ Store(id='viz-skip-frequency', data=5),
218
+ Store(id='viz-num-trajectories', data=3),
219
+ Div(id='viewport-sizer', style={'display': 'none'}),
220
+
221
+ # navbar
222
+ dbc.Navbar(
223
+ dbc.Container([
224
+ # Img(src=LOGO_FILE, height="30px", style={'margin-right': '10px'}),
225
+ dbc.NavbarBrand("JaxPlan Dashboard"),
226
+ dbc.Nav([
227
+ dbc.NavItem(
228
+ dbc.NavLink(
229
+ "Docs",
230
+ href="https://pyrddlgym.readthedocs.io/en/latest/jax.html"
231
+ )
232
+ ),
233
+ dbc.NavItem(
234
+ dbc.NavLink(
235
+ "GitHub",
236
+ href="https://github.com/pyrddlgym-project/pyRDDLGym-jax"
237
+ )
238
+ ),
239
+ dbc.NavItem(
240
+ dbc.NavLink(
241
+ "Submit an Issue",
242
+ href="https://github.com/pyrddlgym-project/pyRDDLGym-jax/issues"
243
+ )
244
+ )
245
+ ], navbar=True, className="me-auto"),
246
+ dbc.Nav([
247
+ dbc.DropdownMenu(
248
+ [dbc.DropdownMenuItem("500ms", id='05sec'),
249
+ dbc.DropdownMenuItem("1s", id='1sec'),
250
+ dbc.DropdownMenuItem("2s", id='2sec'),
251
+ dbc.DropdownMenuItem("5s", id='5sec'),
252
+ dbc.DropdownMenuItem("10s", id='10sec'),
253
+ dbc.DropdownMenuItem("30s", id='30sec'),
254
+ dbc.DropdownMenuItem("1m", id='1min'),
255
+ dbc.DropdownMenuItem("5m", id='5min'),
256
+ dbc.DropdownMenuItem("1d", id='1day')],
257
+ label="Refresh: 2s",
258
+ id='refresh-rate-dropdown',
259
+ nav=True
260
+ )
261
+ ], navbar=True),
262
+ dbc.Nav([
263
+ dbc.DropdownMenu(
264
+ [dbc.DropdownMenuItem("5", id='5pp'),
265
+ dbc.DropdownMenuItem("10", id='10pp'),
266
+ dbc.DropdownMenuItem("25", id='25pp'),
267
+ dbc.DropdownMenuItem("50", id='50pp')],
268
+ label="Exp. Per Page: 10",
269
+ id='experiment-num-per-page-dropdown',
270
+ nav=True
271
+ )
272
+ ], navbar=True)
273
+ ], fluid=True)
274
+ ),
275
+
276
+ # experiments
277
+ dbc.Row([
278
+ dbc.Col([
279
+ dbc.Card([
280
+ dbc.CardBody([
281
+ Div(create_experiment_table(0, 10), id='experiment-table',
282
+ style={'fontSize': f'{EXPERIMENT_ENTRY_FONT_SIZE}px'}),
283
+ dbc.Pagination(id='experiment-pagination',
284
+ active_page=1, max_value=1, size="sm")
285
+ ], style={'padding': '10px'})
286
+ ], className="border-0 bg-transparent")
287
+ ])
288
+ ]),
289
+
290
+ # empirical results tabs
291
+ dbc.Row([
292
+ dbc.Col([
293
+ dbc.Tabs([
294
+
295
+ # returns
296
+ dbc.Tab(dbc.Card(
297
+ dbc.CardBody([
298
+ dbc.Row([
299
+ dbc.Col(Graph(id='train-return-graph'), width=6),
300
+ dbc.Col(Graph(id='test-return-graph'), width=6),
301
+ ]),
302
+ dbc.Row([
303
+ Graph(id='dist-return-graph')
304
+ ])
305
+ ]), className="border-0 bg-transparent"
306
+ ), label="Performance", tab_id='tab-performance'
307
+ ),
308
+
309
+ # policy
310
+ dbc.Tab(dbc.Card(
311
+ dbc.CardBody([
312
+ dbc.Row([
313
+ Graph(id='action-output'),
314
+ ]),
315
+ dbc.Row([
316
+ dbc.Col([
317
+ Hr(className='my-4')
318
+ ])
319
+ ]),
320
+ dbc.Row([
321
+ Graph(id='policy-params'),
322
+ ]),
323
+ dbc.Row([
324
+ dbc.Col([
325
+ Hr(className='my-4')
326
+ ])
327
+ ]),
328
+ dbc.Row([
329
+ dbc.Col([
330
+ dbc.Row([
331
+ dbc.Col([
332
+ dbc.DropdownMenu(
333
+ [dbc.DropdownMenuItem('Every Frame', id='viz-skip-1'),
334
+ dbc.DropdownMenuItem('Every 2 Frames', id='viz-skip-2'),
335
+ dbc.DropdownMenuItem('Every 3 Frames', id='viz-skip-3'),
336
+ dbc.DropdownMenuItem('Every 4 Frames', id='viz-skip-4'),
337
+ dbc.DropdownMenuItem('Every 5 Frames', id='viz-skip-5'),
338
+ dbc.DropdownMenuItem('Every 10 Frames', id='viz-skip-10')],
339
+ label="Render: Every 5 Frames",
340
+ id='viz-skip-dropdown'
341
+ ),
342
+ ], width='auto'),
343
+ dbc.Col([
344
+ dbc.DropdownMenu(
345
+ [dbc.DropdownMenuItem('1', id='viz-num-1'),
346
+ dbc.DropdownMenuItem('2', id='viz-num-2'),
347
+ dbc.DropdownMenuItem('3', id='viz-num-3'),
348
+ dbc.DropdownMenuItem('4', id='viz-num-4'),
349
+ dbc.DropdownMenuItem('5', id='viz-num-5')],
350
+ label="Max. Trajectories: 3",
351
+ id='viz-num-dropdown'
352
+ ),
353
+ ], width='auto'),
354
+ dbc.Col([
355
+ dbc.Button('Run Policy Visualization',
356
+ id='policy-viz-button'),
357
+ ], width='auto')
358
+ ]),
359
+ dbc.Row([
360
+ Graph(id='policy-viz')
361
+ ])
362
+ ])
363
+ ]),
364
+ ]), className="border-0 bg-transparent"
365
+ ), label="Policy", tab_id='tab-policy'
366
+ ),
367
+
368
+ # model
369
+ dbc.Tab(dbc.Card(
370
+ dbc.CardBody([
371
+ dbc.Row([
372
+ dbc.Col([
373
+ dbc.Card([
374
+ dbc.DropdownMenu(
375
+ [],
376
+ label="RDDL Expression",
377
+ id='model-params-dropdown'
378
+ ),
379
+ Graph(id='model-params-graph')
380
+ ], className="border-0 bg-transparent"
381
+ ),
382
+ ])
383
+ ]),
384
+ dbc.Row([
385
+ dbc.Col([
386
+ Hr(className='my-4')
387
+ ])
388
+ ]),
389
+ dbc.Row([
390
+ dbc.Col([
391
+ dbc.Card(
392
+ dbc.CardBody(
393
+ Graph(id='model-errors-reward-graph')
394
+ ),
395
+ className="border-0 bg-transparent"
396
+ ),
397
+ ])
398
+ ]),
399
+ dbc.Row([
400
+ dbc.Col([
401
+ dbc.Card([
402
+ dbc.DropdownMenu(
403
+ [],
404
+ label="State-Fluent",
405
+ id='model-errors-state-dropdown'
406
+ ),
407
+ Graph(id='model-errors-state-graph')
408
+ ], className="border-0 bg-transparent"
409
+ )
410
+ ])
411
+ ]),
412
+ ]), className="border-0 bg-transparent"
413
+ ), label="Model", tab_id='tab-model'
414
+ ),
415
+
416
+ # information
417
+ dbc.Tab(dbc.Card(
418
+ dbc.CardBody([
419
+ dbc.Row([
420
+ dbc.Alert(id="planner-info", color="light", dismissable=False,
421
+ style={"fontFamily": "Courier, monospace"})
422
+ ]),
423
+ ]), className="border-0 bg-transparent"
424
+ ), label="Debug", tab_id='tab-debug'
425
+ ),
426
+
427
+ # tuning
428
+ dbc.Tab(dbc.Card(
429
+ dbc.CardBody([
430
+ dbc.Row([
431
+ dbc.Col(Graph(id='tuning-target-graph'), width=6),
432
+ dbc.Col(Graph(id='tuning-scatter-graph'), width=6)
433
+ ]),
434
+ dbc.Row([
435
+ dbc.Col([
436
+ Hr(className='my-4')
437
+ ])
438
+ ]),
439
+ dbc.Row([
440
+ dbc.Col(Graph(id='tuning-gp-mean-graph'))
441
+ ]),
442
+ dbc.Row([
443
+ dbc.Col(Graph(id='tuning-gp-unc-graph'))
444
+ ])
445
+ ]), className="border-0 bg-transparent"
446
+ ), label="Tuning", tab_id='tab-tuning'
447
+ ),
448
+ ], id='tabs-main')
449
+ ], width=12)
450
+ ]),
451
+
452
+ # refresh interval
453
+ Interval(
454
+ id='interval',
455
+ interval=2000,
456
+ n_intervals=0
457
+ ),
458
+ Div(id='trigger-experiment-check', style={'display': 'none'})
459
+ ], fluid=True, className="dbc")
460
+
461
+ # JavaScript to retrieve the viewport dimensions
462
+ app.clientside_callback(
463
+ """
464
+ function(n_intervals) {
465
+ return {
466
+ 'height': window.innerHeight,
467
+ 'width': window.innerWidth
468
+ };
469
+ }
470
+ """,
471
+ Output('viewport-sizer', 'children'),
472
+ Input('interval', 'n_intervals')
473
+ )
474
+
475
+ # ======================================================================
476
+ # CREATE EVENTS
477
+ # ======================================================================
478
+
479
+ # modify refresh rate
480
+ @app.callback(
481
+ Output("refresh-interval", "data"),
482
+ [Input("05sec", "n_clicks"),
483
+ Input("1sec", "n_clicks"),
484
+ Input("2sec", "n_clicks"),
485
+ Input("5sec", "n_clicks"),
486
+ Input("10sec", "n_clicks"),
487
+ Input("30sec", "n_clicks"),
488
+ Input("1min", "n_clicks"),
489
+ Input("5min", "n_clicks"),
490
+ Input("1day", "n_clicks")],
491
+ [State('refresh-interval', 'data')]
492
+ )
493
+ def click_refresh_rate(n05, n1, n2, n5, n10, n30, n1m, n5m, nd, data):
494
+ ctx = dash.callback_context
495
+ if not ctx.triggered:
496
+ return data
497
+ button_id = ctx.triggered[0]['prop_id'].split('.')[0]
498
+ if button_id == '05sec':
499
+ return 500
500
+ elif button_id == '1sec':
501
+ return 1000
502
+ elif button_id == '2sec':
503
+ return 2000
504
+ elif button_id == '5sec':
505
+ return 5000
506
+ elif button_id == '10sec':
507
+ return 10000
508
+ elif button_id == '30sec':
509
+ return 30000
510
+ elif button_id == '1min':
511
+ return 60000
512
+ elif button_id == '5min':
513
+ return 300000
514
+ elif button_id == '1day':
515
+ return 86400000
516
+ return data
517
+
518
+ @app.callback(
519
+ Output('interval', 'interval'),
520
+ [Input('refresh-interval', 'data')]
521
+ )
522
+ def update_refresh_rate(selected_interval):
523
+ return selected_interval if selected_interval else 2000
524
+
525
+ @app.callback(
526
+ Output('refresh-rate-dropdown', 'label'),
527
+ [Input('refresh-interval', 'data')]
528
+ )
529
+ def update_refresh_rate_label(selected_interval):
530
+ if selected_interval == 500:
531
+ return 'Refresh: 500ms'
532
+ elif selected_interval == 1000:
533
+ return 'Refresh: 1s'
534
+ elif selected_interval == 2000:
535
+ return 'Refresh: 2s'
536
+ elif selected_interval == 5000:
537
+ return 'Refresh: 5s'
538
+ elif selected_interval == 10000:
539
+ return 'Refresh: 10s'
540
+ elif selected_interval == 30000:
541
+ return 'Refresh: 30s'
542
+ elif selected_interval == 60000:
543
+ return 'Refresh: 1m'
544
+ elif selected_interval == 300000:
545
+ return 'Refresh: 5m'
546
+ else:
547
+ return 'Refresh: 2s'
548
+
549
+ # update the experiments per page
550
+ @app.callback(
551
+ Output("experiment-num-per-page", "data"),
552
+ [Input("5pp", "n_clicks"),
553
+ Input("10pp", "n_clicks"),
554
+ Input("25pp", "n_clicks"),
555
+ Input("50pp", "n_clicks")],
556
+ [State('experiment-num-per-page', 'data')]
557
+ )
558
+ def click_experiments_per_page(n5, n10, n25, n50, data):
559
+ ctx = dash.callback_context
560
+ if not ctx.triggered:
561
+ return data
562
+ button_id = ctx.triggered[0]['prop_id'].split('.')[0]
563
+ if button_id == '5pp':
564
+ return 5
565
+ elif button_id == '10pp':
566
+ return 10
567
+ elif button_id == '25pp':
568
+ return 25
569
+ elif button_id == '50pp':
570
+ return 50
571
+ return data
572
+
573
+ @app.callback(
574
+ Output('experiment-num-per-page-dropdown', 'label'),
575
+ [Input('experiment-num-per-page', 'data')]
576
+ )
577
+ def update_experiments_per_page(selected_num):
578
+ return f'Exp. Per Page: {selected_num}'
579
+
580
+ # update the experiment table
581
+ @app.callback(
582
+ Output('experiment-table', 'children'),
583
+ [Input('interval', 'n_intervals'),
584
+ Input('experiment-pagination', 'active_page')],
585
+ State('experiment-num-per-page', 'data')
586
+ )
587
+ def update_experiment_table(n, active_page, npp):
588
+ return create_experiment_table(active_page, npp)
589
+
590
+ @app.callback(
591
+ Output('experiment-pagination', 'max_value'),
592
+ Input('interval', 'n_intervals'),
593
+ State('experiment-num-per-page', 'data')
594
+ )
595
+ def update_experiment_max_pages(n, npp):
596
+ return (len(self.checked) + npp - 1) // npp
597
+
598
+ @app.callback(
599
+ Output('trigger-experiment-check', 'children'),
600
+ Input({'type': 'experiment-checkbox', 'index': ALL}, 'value'),
601
+ State({'type': 'experiment-checkbox', 'index': ALL}, 'id')
602
+ )
603
+ def update_checked_experiment_status(checked, ids):
604
+ for (i, chk) in enumerate(checked):
605
+ row = ids[i]['index']
606
+ self.checked[row] = chk
607
+ return time.time()
608
+
609
+ # update the return information
610
+ @app.callback(
611
+ Output('train-return-graph', 'figure'),
612
+ [Input('interval', 'n_intervals'),
613
+ Input('trigger-experiment-check', 'children'),
614
+ Input('tabs-main', 'active_tab')]
615
+ )
616
+ def update_train_return_graph(n, trigger, active_tab):
617
+ if active_tab != 'tab-performance': return dash.no_update
618
+ fig = go.Figure()
619
+ for (row, checked) in self.checked.copy().items():
620
+ if checked:
621
+ fig.add_trace(go.Scatter(
622
+ x=self.xticks[row], y=self.train_return[row],
623
+ name=f'id={row}',
624
+ mode='lines+markers',
625
+ marker=dict(size=3), line=dict(width=2)
626
+ ))
627
+ fig.update_layout(
628
+ title=dict(text="Train Return"),
629
+ xaxis=dict(title=dict(text="Training Iteration")),
630
+ yaxis=dict(title=dict(text="Cumulative Reward")),
631
+ font=dict(size=PLOT_AXES_FONT_SIZE),
632
+ legend=dict(bgcolor='rgba(0,0,0,0)'),
633
+ template="plotly_white"
634
+ )
635
+ return fig
636
+
637
+ @app.callback(
638
+ Output('test-return-graph', 'figure'),
639
+ [Input('interval', 'n_intervals'),
640
+ Input('trigger-experiment-check', 'children'),
641
+ Input('tabs-main', 'active_tab')]
642
+ )
643
+ def update_test_return_graph(n, trigger, active_tab):
644
+ if active_tab != 'tab-performance': return dash.no_update
645
+ fig = go.Figure()
646
+ for (row, checked) in self.checked.copy().items():
647
+ if checked:
648
+ fig.add_trace(go.Scatter(
649
+ x=self.xticks[row], y=self.test_return[row],
650
+ name=f'id={row}',
651
+ mode='lines+markers',
652
+ marker=dict(size=3), line=dict(width=2)
653
+ ))
654
+ fig.update_layout(
655
+ title=dict(text="Test Return"),
656
+ xaxis=dict(title=dict(text="Training Iteration")),
657
+ yaxis=dict(title=dict(text="Cumulative Reward")),
658
+ font=dict(size=PLOT_AXES_FONT_SIZE),
659
+ legend=dict(bgcolor='rgba(0,0,0,0)'),
660
+ template="plotly_white"
661
+ )
662
+ return fig
663
+
664
+ @app.callback(
665
+ Output('dist-return-graph', 'figure'),
666
+ [Input('interval', 'n_intervals'),
667
+ Input('trigger-experiment-check', 'children'),
668
+ Input('tabs-main', 'active_tab')]
669
+ )
670
+ def update_dist_return_graph(n, trigger, active_tab):
671
+ if active_tab != 'tab-performance': return dash.no_update
672
+ fig = go.Figure()
673
+ fig.update_layout(template='plotly_white')
674
+ for (row, checked) in self.checked.copy().items():
675
+ if checked:
676
+ return_dists = self.return_dist[row]
677
+ ticks = self.return_dist_ticks[row]
678
+ colors = pc.sample_colorscale(
679
+ pc.get_colorscale('Blues'),
680
+ np.linspace(0.1, 1, len(return_dists))
681
+ )
682
+ for (ic, (tick, dist)) in enumerate(zip(ticks, return_dists)):
683
+ fig.add_trace(go.Violin(
684
+ y=dist, line_color=colors[ic], name=f'{tick}'
685
+ ))
686
+ fig.update_traces(
687
+ orientation='v', side='positive', width=3, points=False)
688
+ fig.update_layout(
689
+ title=dict(text="Distribution of Return"),
690
+ xaxis=dict(title=dict(text="Training Iteration")),
691
+ yaxis=dict(title=dict(text="Cumulative Reward")),
692
+ font=dict(size=PLOT_AXES_FONT_SIZE),
693
+ showlegend=False,
694
+ yaxis_showgrid=False, yaxis_zeroline=False,
695
+ template="plotly_white"
696
+ )
697
+ break
698
+ return fig
699
+
700
+ # update the action heatmap
701
+ @app.callback(
702
+ Output('action-output', 'figure'),
703
+ [Input('interval', 'n_intervals'),
704
+ Input('trigger-experiment-check', 'children'),
705
+ Input('tabs-main', 'active_tab')]
706
+ )
707
+ def update_action_heatmap(n, trigger, active_tab):
708
+ if active_tab != 'tab-policy': return dash.no_update
709
+ fig = go.Figure()
710
+ fig.update_layout(template='plotly_white')
711
+ for (row, checked) in self.checked.copy().items():
712
+ if checked and self.action_output[row] is not None:
713
+ num_plots = len(self.action_output[row])
714
+ titles = []
715
+ for (_, act, _) in self.action_output[row]:
716
+ titles.append(f'Values of Action-Fluents {act}')
717
+ titles.append(f'Std. Dev. of Action-Fluents {act}')
718
+ fig = make_subplots(
719
+ rows=num_plots, cols=2,
720
+ shared_xaxes=True, horizontal_spacing=0.15,
721
+ subplot_titles=titles
722
+ )
723
+ for (i, (action_output, action, action_labels)) \
724
+ in enumerate(self.action_output[row]):
725
+ action_values = np.mean(1. * action_output, axis=0).T
726
+ action_errors = np.std(1. * action_output, axis=0).T
727
+ fig.add_trace(go.Heatmap(
728
+ z=action_values,
729
+ x=np.arange(action_values.shape[1]),
730
+ y=np.arange(action_values.shape[0]),
731
+ colorscale='Blues', colorbar_x=0.45,
732
+ colorbar_len=0.8 / num_plots,
733
+ colorbar_y=1 - (i + 0.5) / num_plots
734
+ ), row=i + 1, col=1)
735
+ fig.add_trace(go.Heatmap(
736
+ z=action_errors,
737
+ x=np.arange(action_errors.shape[1]),
738
+ y=np.arange(action_errors.shape[0]),
739
+ colorscale='Reds', colorbar_len=0.8 / num_plots,
740
+ colorbar_y=1 - (i + 0.5) / num_plots
741
+ ), row=i + 1, col=2)
742
+ fig.update_layout(
743
+ title="Values of Action-Fluents",
744
+ xaxis=dict(title=dict(text="Decision Epoch")),
745
+ font=dict(size=PLOT_AXES_FONT_SIZE),
746
+ height=ACTION_HEATMAP_HEIGHT * num_plots,
747
+ showlegend=False,
748
+ template="plotly_white"
749
+ )
750
+ break
751
+ return fig
752
+
753
+ # update the weight histograms
754
+ @app.callback(
755
+ Output('policy-params', 'figure'),
756
+ [Input('interval', 'n_intervals'),
757
+ Input('trigger-experiment-check', 'children'),
758
+ Input('tabs-main', 'active_tab')]
759
+ )
760
+ def update_policy_params(n, trigger, active_tab):
761
+ if active_tab != 'tab-policy': return dash.no_update
762
+ fig = go.Figure()
763
+ fig.update_layout(template='plotly_white')
764
+ for (row, checked) in self.checked.copy().items():
765
+ policy_params = self.policy_params[row]
766
+ policy_params_ticks = self.policy_params_ticks[row]
767
+ if checked and policy_params is not None and policy_params:
768
+ titles = []
769
+ for (layer_name, layer_params) in policy_params[0].items():
770
+ if isinstance(layer_params, dict):
771
+ for weight_name in layer_params:
772
+ titles.append(f'{layer_name}/{weight_name}')
773
+
774
+ n_rows = math.ceil(len(policy_params) / POLICY_DIST_PLOTS_PER_ROW)
775
+ fig = make_subplots(
776
+ rows=n_rows, cols=POLICY_DIST_PLOTS_PER_ROW,
777
+ shared_xaxes=True,
778
+ subplot_titles=titles
779
+ )
780
+ colors = pc.sample_colorscale(
781
+ pc.get_colorscale('Blues'),
782
+ np.linspace(0.1, 1, len(policy_params))
783
+ )[::-1]
784
+
785
+ for (it, (tick, policy_params_t)) in enumerate(
786
+ zip(policy_params_ticks[::-1], policy_params[::-1])):
787
+ r, c = 1, 1
788
+ for (layer_name, layer_params) in policy_params_t.items():
789
+ if isinstance(layer_params, dict):
790
+ for (weight_name, weight_values) in layer_params.items():
791
+ if r <= n_rows:
792
+ fig.add_trace(go.Violin(
793
+ x=np.ravel(weight_values),
794
+ line_color=colors[it], name=f'{tick}'
795
+ ), row=r, col=c)
796
+ c += 1
797
+ if c > POLICY_DIST_PLOTS_PER_ROW:
798
+ r += 1
799
+ c = 1
800
+ fig.update_traces(
801
+ orientation='h', side='positive', width=3, points=False)
802
+ fig.update_layout(
803
+ title="Distribution of Network Weight Parameters",
804
+ font=dict(size=PLOT_AXES_FONT_SIZE),
805
+ showlegend=False,
806
+ xaxis_showgrid=False, xaxis_zeroline=False,
807
+ height=POLICY_DIST_HEIGHT * n_rows,
808
+ template="plotly_white"
809
+ )
810
+ break
811
+ return fig
812
+
813
+ # modify viz skip rate
814
+ @app.callback(
815
+ Output("viz-skip-frequency", "data"),
816
+ [Input("viz-skip-1", "n_clicks"),
817
+ Input("viz-skip-2", "n_clicks"),
818
+ Input("viz-skip-3", "n_clicks"),
819
+ Input("viz-skip-4", "n_clicks"),
820
+ Input("viz-skip-5", "n_clicks"),
821
+ Input("viz-skip-10", "n_clicks")],
822
+ [State('viz-skip-frequency', 'data')]
823
+ )
824
+ def click_viz_skip_rate(v1, v2, v3, v4, v5, v10, data):
825
+ ctx = dash.callback_context
826
+ if not ctx.triggered:
827
+ return data
828
+ button_id = ctx.triggered[0]['prop_id'].split('.')[0]
829
+ if button_id == 'viz-skip-1':
830
+ return 1
831
+ elif button_id == 'viz-skip-2':
832
+ return 2
833
+ elif button_id == 'viz-skip-3':
834
+ return 3
835
+ elif button_id == 'viz-skip-4':
836
+ return 4
837
+ elif button_id == 'viz-skip-5':
838
+ return 5
839
+ elif button_id == 'viz-skip-10':
840
+ return 10
841
+ return data
842
+
843
+ @app.callback(
844
+ Output('viz-skip-dropdown', 'label'),
845
+ [Input('viz-skip-frequency', 'data')]
846
+ )
847
+ def update_viz_skip_dropdown_text(viz_skip):
848
+ if viz_skip == 1:
849
+ return 'Render: Every Frame'
850
+ else:
851
+ return f'Render: Every {viz_skip} Frames'
852
+
853
+ # modify viz count
854
+ @app.callback(
855
+ Output("viz-num-trajectories", "data"),
856
+ [Input("viz-num-1", "n_clicks"),
857
+ Input("viz-num-2", "n_clicks"),
858
+ Input("viz-num-3", "n_clicks"),
859
+ Input("viz-num-4", "n_clicks"),
860
+ Input("viz-num-5", "n_clicks")],
861
+ [State('viz-num-trajectories', 'data')]
862
+ )
863
+ def click_viz_num_render(v1, v2, v3, v4, v5, data):
864
+ ctx = dash.callback_context
865
+ if not ctx.triggered:
866
+ return data
867
+ button_id = ctx.triggered[0]['prop_id'].split('.')[0]
868
+ if button_id == 'viz-num-1':
869
+ return 1
870
+ elif button_id == 'viz-num-2':
871
+ return 2
872
+ elif button_id == 'viz-num-3':
873
+ return 3
874
+ elif button_id == 'viz-num-4':
875
+ return 4
876
+ elif button_id == 'viz-num-5':
877
+ return 5
878
+ return data
879
+
880
+ @app.callback(
881
+ Output('viz-num-dropdown', 'label'),
882
+ [Input('viz-num-trajectories', 'data')]
883
+ )
884
+ def update_viz_num_dropdown_text(viz_num):
885
+ return f'Max. Trajectories: {viz_num}'
886
+
887
+ # update the policy viz
888
+ @app.callback(
889
+ Output('policy-viz', 'figure'),
890
+ Input("policy-viz-button", "n_clicks"),
891
+ [State('viewport-sizer', 'children'),
892
+ State("viz-skip-frequency", "data"),
893
+ State("viz-num-trajectories", "data")]
894
+ )
895
+ def update_policy_viz(n_clicks, viewport_size, skip_freq, viz_num):
896
+ if not viewport_size: return dash.no_update
897
+ if not n_clicks: return dash.no_update
898
+
899
+ for (row, checked) in self.checked.copy().items():
900
+ viz = self.policy_viz[row]
901
+ if checked and viz is not None:
902
+ states = self.train_state_fluents[row]
903
+ lookahead = next(iter(states.values())).shape[1]
904
+ batch_idx = self.representative_trajectories(states, k=viz_num)
905
+ policy_viz_frames = []
906
+ for idx in batch_idx:
907
+ avg_image = 0.
908
+ num_image = 0
909
+ viz.__init__(self.rddl[row])
910
+ for t in range(0, lookahead, skip_freq):
911
+ state_t = {name: values[idx, t]
912
+ for (name, values) in states.items()}
913
+ state_t = self.rddl[row].ground_vars_with_values(state_t)
914
+ avg_image += np.asarray(viz.render(state_t))
915
+ num_image += 1
916
+ avg_image /= num_image
917
+ policy_viz_frames.append(avg_image)
918
+
919
+ subplot_width = min(
920
+ viewport_size['width'] // len(policy_viz_frames),
921
+ POLICY_STATE_VIZ_MAX_HEIGHT)
922
+ fig = make_subplots(
923
+ rows=1, cols=len(policy_viz_frames)
924
+ )
925
+ for (col, frame) in enumerate(policy_viz_frames):
926
+ fig.add_trace(go.Image(z=frame, hoverinfo='skip'),
927
+ row=1, col=1 + col)
928
+ fig.update_layout(
929
+ title="Representative Trajectories",
930
+ font=dict(size=PLOT_AXES_FONT_SIZE),
931
+ xaxis=dict(showticklabels=False),
932
+ yaxis=dict(showticklabels=False),
933
+ width=subplot_width * len(policy_viz_frames),
934
+ height=subplot_width * 1,
935
+ showlegend=False,
936
+ template="plotly_white"
937
+ )
938
+ return fig
939
+ return dash.no_update
940
+
941
+ # update the model parameter information
942
+ @app.callback(
943
+ Output('model-params-dropdown', 'children'),
944
+ [Input('trigger-experiment-check', 'children'),
945
+ Input('tabs-main', 'active_tab')]
946
+ )
947
+ def update_model_params_dropdown_create(trigger, active_tab):
948
+ if active_tab != 'tab-model': return dash.no_update
949
+ items = []
950
+ for (row, checked) in self.checked.copy().items():
951
+ if checked:
952
+ items = []
953
+ for (expr_id, expr) in self.relaxed_exprs[row].items():
954
+ items.append(dbc.DropdownMenuItem([
955
+ B(f'{expr_id}: '),
956
+ expr.replace('\n', ' ')[:120]
957
+ ], id={'type': 'expr-dropdown-item', 'index': expr_id}))
958
+ break
959
+ return items
960
+
961
+ @app.callback(
962
+ Output('model-params-dropdown-expr', 'data'),
963
+ Input({'type': 'expr-dropdown-item', 'index': ALL}, 'n_clicks')
964
+ )
965
+ def update_model_params_dropdown_select(n_clicks):
966
+ ctx = dash.callback_context
967
+ if not ctx.triggered:
968
+ return dash.no_update
969
+ if not next((item for item in n_clicks if item is not None), False):
970
+ return dash.no_update
971
+ return ast.literal_eval(
972
+ ctx.triggered[0]['prop_id'].split('.n_clicks')[0])['index']
973
+
974
+ @app.callback(
975
+ Output('model-params-graph', 'figure'),
976
+ [Input('interval', 'n_intervals'),
977
+ Input('tabs-main', 'active_tab')],
978
+ [State('model-params-dropdown-expr', 'data')]
979
+ )
980
+ def update_model_params_graph(n, active_tab, expr_id):
981
+ if active_tab != 'tab-model': return dash.no_update
982
+ fig = go.Figure()
983
+ fig.update_layout(template='plotly_white')
984
+ if expr_id == '': return fig
985
+ for (row, checked) in self.checked.copy().items():
986
+ if checked:
987
+ fig.add_trace(go.Scatter(
988
+ x=self.xticks[row],
989
+ y=self.relaxed_exprs_values[row][expr_id],
990
+ mode='lines+markers',
991
+ marker=dict(size=3), line=dict(width=2)
992
+ ))
993
+ fig.update_layout(
994
+ title=dict(text=f"Model Parameters for Expression {expr_id}"),
995
+ xaxis=dict(title=dict(text="Training Iteration")),
996
+ yaxis=dict(title=dict(text="Parameter Value")),
997
+ font=dict(size=PLOT_AXES_FONT_SIZE),
998
+ legend=dict(bgcolor='rgba(0,0,0,0)'),
999
+ template="plotly_white"
1000
+ )
1001
+ break
1002
+ return fig
1003
+
1004
+ # update the model errors information for reward
1005
+ @app.callback(
1006
+ Output('model-errors-reward-graph', 'figure'),
1007
+ [Input('interval', 'n_intervals'),
1008
+ Input('trigger-experiment-check', 'children'),
1009
+ Input('tabs-main', 'active_tab')]
1010
+ )
1011
+ def update_model_error_reward_graph(n, trigger, active_tab):
1012
+ if active_tab != 'tab-model': return dash.no_update
1013
+ fig = go.Figure()
1014
+ fig.update_layout(template='plotly_white')
1015
+ for (row, checked) in self.checked.copy().items():
1016
+ if checked and row in self.train_reward_dist:
1017
+ data = self.train_reward_dist[row]
1018
+ num_epochs = data.shape[1]
1019
+ step = 1
1020
+ if num_epochs > REWARD_ERROR_DIST_SUBPLOTS:
1021
+ step = num_epochs // REWARD_ERROR_DIST_SUBPLOTS
1022
+ for epoch in range(0, num_epochs, step):
1023
+ fig.add_trace(go.Violin(
1024
+ y=self.train_reward_dist[row][:, epoch], x0=epoch,
1025
+ side='negative', line_color='red',
1026
+ name=f'Train Epoch {epoch + 1}'
1027
+ ))
1028
+ fig.add_trace(go.Violin(
1029
+ y=self.test_reward_dist[row][:, epoch], x0=epoch,
1030
+ side='positive', line_color='blue',
1031
+ name=f'Test Epoch {epoch + 1}'
1032
+ ))
1033
+ fig.update_traces(meanline_visible=True)
1034
+ fig.update_layout(
1035
+ title=dict(text="Distribution of Reward in Relaxed Model vs True Model"),
1036
+ xaxis=dict(title=dict(text="Decision Epoch")),
1037
+ yaxis=dict(title=dict(text="Reward")),
1038
+ font=dict(size=PLOT_AXES_FONT_SIZE),
1039
+ violingap=0, violinmode='overlay', showlegend=False,
1040
+ legend=dict(bgcolor='rgba(0,0,0,0)'),
1041
+ template="plotly_white"
1042
+ )
1043
+ break
1044
+ return fig
1045
+
1046
+ # update the model errors information for state
1047
+ @app.callback(
1048
+ Output('model-errors-state-dropdown', 'children'),
1049
+ [Input('trigger-experiment-check', 'children'),
1050
+ Input('tabs-main', 'active_tab')]
1051
+ )
1052
+ def update_model_errors_state_dropdown_create(trigger, active_tab):
1053
+ if active_tab != 'tab-model': return dash.no_update
1054
+ items = []
1055
+ for (row, checked) in self.checked.copy().items():
1056
+ if checked:
1057
+ items = []
1058
+ for name in self.train_state_fluents[row]:
1059
+ items.append(dbc.DropdownMenuItem(
1060
+ [name],
1061
+ id={'type': 'state-fluent-dropdown-item', 'index': name}
1062
+ ))
1063
+ break
1064
+ return items
1065
+
1066
+ @app.callback(
1067
+ Output('model-errors-state-dropdown-selected', 'data'),
1068
+ Input({'type': 'state-fluent-dropdown-item', 'index': ALL}, 'n_clicks')
1069
+ )
1070
+ def update_model_errors_state_dropdown_select(n_clicks):
1071
+ ctx = dash.callback_context
1072
+ if not ctx.triggered:
1073
+ return dash.no_update
1074
+ if not next((item for item in n_clicks if item is not None), False):
1075
+ return dash.no_update
1076
+ return ast.literal_eval(
1077
+ ctx.triggered[0]['prop_id'].split('.n_clicks')[0])['index']
1078
+
1079
+ @app.callback(
1080
+ Output('model-errors-state-graph', 'figure'),
1081
+ [Input('interval', 'n_intervals'),
1082
+ Input('trigger-experiment-check', 'children'),
1083
+ Input('tabs-main', 'active_tab')],
1084
+ [State('model-errors-state-dropdown-selected', 'data')]
1085
+ )
1086
+ def update_model_errors_state_graph(n, trigger, active_tab, state):
1087
+ if active_tab != 'tab-model': return dash.no_update
1088
+ fig = go.Figure()
1089
+ fig.update_layout(template='plotly_white')
1090
+ if not state: return fig
1091
+ for (row, checked) in self.checked.copy().items():
1092
+ if checked and row in self.train_state_fluents:
1093
+ train_values = self.train_state_fluents[row][state]
1094
+ test_values = self.test_state_fluents[row][state]
1095
+ train_values = 1 * train_values.reshape(train_values.shape[:2] + (-1,))
1096
+ test_values = 1 * test_values.reshape(test_values.shape[:2] + (-1,))
1097
+ num_epochs, num_states = train_values.shape[1:]
1098
+ step = 1
1099
+ if num_epochs > REWARD_ERROR_DIST_SUBPLOTS:
1100
+ step = num_epochs // REWARD_ERROR_DIST_SUBPLOTS
1101
+ fig = make_subplots(
1102
+ rows=num_states, cols=1, shared_xaxes=True,
1103
+ subplot_titles=self.rddl[row].variable_groundings[state]
1104
+ )
1105
+ for istate in range(num_states):
1106
+ for epoch in range(0, num_epochs, step):
1107
+ fig.add_trace(go.Violin(
1108
+ y=train_values[:, epoch, istate], x0=epoch,
1109
+ side='negative', line_color='red',
1110
+ name=f'Train Epoch {epoch + 1}'
1111
+ ), row=istate + 1, col=1)
1112
+ fig.add_trace(go.Violin(
1113
+ y=test_values[:, epoch, istate], x0=epoch,
1114
+ side='positive', line_color='blue',
1115
+ name=f'Test Epoch {epoch + 1}'
1116
+ ), row=istate + 1, col=1)
1117
+ fig.update_traces(meanline_visible=True)
1118
+ fig.update_layout(
1119
+ title=dict(text=(f"Distribution of State-Fluent {state} "
1120
+ f"in Relaxed Model vs True Model")),
1121
+ xaxis=dict(title=dict(text="Decision Epoch")),
1122
+ yaxis=dict(title=dict(text="State-Fluent Value")),
1123
+ font=dict(size=PLOT_AXES_FONT_SIZE),
1124
+ height=MODEL_STATE_ERROR_HEIGHT * num_states,
1125
+ violingap=0, violinmode='overlay', showlegend=False,
1126
+ legend=dict(bgcolor='rgba(0,0,0,0)'),
1127
+ template="plotly_white"
1128
+ )
1129
+ break
1130
+ return fig
1131
+
1132
+ # update the run information
1133
+ @app.callback(
1134
+ Output('planner-info', 'children'),
1135
+ [Input('interval', 'n_intervals'),
1136
+ Input('trigger-experiment-check', 'children'),
1137
+ Input('tabs-main', 'active_tab')]
1138
+ )
1139
+ def update_planner_info(n, trigger, active_tab):
1140
+ if active_tab != 'tab-debug': return dash.no_update
1141
+ result = []
1142
+ for (row, checked) in self.checked.copy().items():
1143
+ if checked:
1144
+ result = [
1145
+ H4(f'Hyper-Parameters [id={row}]', className="alert-heading"),
1146
+ P(self.planner_info[row], style={"whiteSpace": "pre-wrap"})
1147
+ ]
1148
+ break
1149
+ return result
1150
+
1151
+ # update the tuning result
1152
+ @app.callback(
1153
+ [Output('tuning-target-graph', 'figure'),
1154
+ Output('tuning-scatter-graph', 'figure'),
1155
+ Output('tuning-gp-mean-graph', 'figure'),
1156
+ Output('tuning-gp-unc-graph', 'figure')],
1157
+ [Input('interval', 'n_intervals'),
1158
+ Input('tabs-main', 'active_tab'),
1159
+ Input('viewport-sizer', 'children')]
1160
+ )
1161
+ def update_tuning_gp_graph(n, active_tab, viewport_size):
1162
+ if not self.tuning_gp_update: return dash.no_update
1163
+ if not viewport_size: return dash.no_update
1164
+
1165
+ # tuning target trend
1166
+ fig1 = go.Figure()
1167
+ fig1.add_trace(go.Scatter(
1168
+ x=np.arange(len(self.tuning_gp_targets)), y=self.tuning_gp_targets,
1169
+ mode='lines+markers',
1170
+ marker=dict(size=3), line=dict(width=2)
1171
+ ))
1172
+ fig1.update_layout(
1173
+ title=dict(text="Target Values of Trial Points"),
1174
+ xaxis=dict(title=dict(text="Trial Point")),
1175
+ yaxis=dict(title=dict(text="Target Value")),
1176
+ font=dict(size=PLOT_AXES_FONT_SIZE),
1177
+ legend=dict(bgcolor='rgba(0,0,0,0)'),
1178
+ template="plotly_white"
1179
+ )
1180
+
1181
+ # tuning scatter actual and predicted
1182
+ fig2 = go.Figure()
1183
+ fig2.add_trace(go.Scatter(
1184
+ x=self.tuning_gp_targets, y=self.tuning_gp_predicted,
1185
+ mode='markers', marker=dict(size=6)
1186
+ ))
1187
+ fig2.add_shape(
1188
+ type="line",
1189
+ x0=np.min(self.tuning_gp_targets),
1190
+ y0=np.min(self.tuning_gp_targets),
1191
+ x1=np.max(self.tuning_gp_targets),
1192
+ y1=np.max(self.tuning_gp_targets),
1193
+ line=dict(dash="dot", color='gray')
1194
+ )
1195
+ fig2.update_layout(
1196
+ title=dict(text="Gaussian Process Goodness-of-Fit Plot"),
1197
+ xaxis=dict(title=dict(text="Actual Target Value")),
1198
+ yaxis=dict(title=dict(text="Predicted Target Value")),
1199
+ font=dict(size=PLOT_AXES_FONT_SIZE),
1200
+ legend=dict(bgcolor='rgba(0,0,0,0)'),
1201
+ template="plotly_white"
1202
+ )
1203
+
1204
+ # tuning posterior plot
1205
+ num_cols = len(self.tuning_gp_heatmaps)
1206
+ num_rows = len(self.tuning_gp_heatmaps[0])
1207
+ fig3 = make_subplots(rows=num_rows, cols=num_cols)
1208
+ fig4 = make_subplots(rows=num_rows, cols=num_cols)
1209
+ for col, data_col in enumerate(self.tuning_gp_heatmaps):
1210
+ for row, data in enumerate(data_col):
1211
+ p1, p2, p1v, p2v, mean, std = data
1212
+ fig3.add_trace(go.Heatmap(
1213
+ z=mean, x=p1v, y=p2v, colorscale='Blues', showscale=False
1214
+ ), row=row + 1, col=col + 1)
1215
+ fig4.add_trace(go.Heatmap(
1216
+ z=std, x=p1v, y=p2v, colorscale='Reds', showscale=False
1217
+ ), row=row + 1, col=col + 1)
1218
+ fig3.add_trace(go.Scatter(
1219
+ x=self.tuning_gp_params[p1],
1220
+ y=self.tuning_gp_params[p2],
1221
+ opacity=1,
1222
+ mode='markers',
1223
+ marker=dict(size=5, color='green', symbol='x')
1224
+ ), row=row + 1, col=col + 1)
1225
+ fig4.add_trace(go.Scatter(
1226
+ x=self.tuning_gp_params[p1],
1227
+ y=self.tuning_gp_params[p2],
1228
+ opacity=1,
1229
+ mode='markers',
1230
+ marker=dict(size=5, color='green', symbol='x')
1231
+ ), row=row + 1, col=col + 1)
1232
+ fig3.update_xaxes(title_text=p1, row=row + 1, col=col + 1)
1233
+ fig4.update_xaxes(title_text=p1, row=row + 1, col=col + 1)
1234
+ fig3.update_yaxes(title_text=p2, row=row + 1, col=col + 1)
1235
+ fig4.update_yaxes(title_text=p2, row=row + 1, col=col + 1)
1236
+ subplot_width = min(
1237
+ GP_POSTERIOR_MAX_HEIGHT, viewport_size['width'] // num_cols)
1238
+ fig3.update_layout(
1239
+ title="Posterior Mean of Gaussian Process",
1240
+ font=dict(size=PLOT_AXES_FONT_SIZE),
1241
+ height=subplot_width * num_rows,
1242
+ width=subplot_width * num_cols,
1243
+ autosize=False,
1244
+ showlegend=False,
1245
+ template="plotly_white"
1246
+ )
1247
+ fig4.update_layout(
1248
+ title="Posterior Uncertainty of Gaussian Process",
1249
+ font=dict(size=PLOT_AXES_FONT_SIZE),
1250
+ height=subplot_width * num_rows,
1251
+ width=subplot_width * num_cols,
1252
+ autosize=False,
1253
+ showlegend=False,
1254
+ template="plotly_white"
1255
+ )
1256
+
1257
+ self.tuning_gp_update = False
1258
+ return (fig1, fig2, fig3, fig4)
1259
+
1260
+ self.app = app
1261
+
1262
+ # ==========================================================================
1263
+ # DASHBOARD EXECUTION
1264
+ # ==========================================================================
1265
+
1266
+ def launch(self, port: int=1222, daemon: bool=True) -> None:
1267
+ '''Launches the dashboard in a browser window.'''
1268
+
1269
+ # open the browser to the required port
1270
+ if not os.environ.get("WERKZEUG_RUN_MAIN"):
1271
+ webbrowser.open_new(f'http://127.0.0.1:{port}/')
1272
+
1273
+ # run the app in a new thread at the specified port
1274
+ def run_dash():
1275
+ self.app.run(port=port)
1276
+
1277
+ dash_thread = threading.Thread(target=run_dash)
1278
+ dash_thread.daemon = daemon
1279
+ dash_thread.start()
1280
+
1281
+ @staticmethod
1282
+ def get_planner_info(planner: 'JaxBackpropPlanner') -> Dict[str, Any]:
1283
+ '''Some additional info directly from the planner that is required by
1284
+ the dashboard.'''
1285
+ return {
1286
+ 'rddl': planner.rddl,
1287
+ 'string': planner.summarize_system() + str(planner),
1288
+ 'model_parameter_info': planner.compiled.model_parameter_info(),
1289
+ 'trace_info': planner.compiled.traced
1290
+ }
1291
+
1292
+ def register_experiment(self, experiment_id: str,
1293
+ planner_info: Dict[str, Any],
1294
+ key: Optional[int]=None,
1295
+ viz: Optional[Any]=None) -> str:
1296
+ '''Starts monitoring a new experiment.'''
1297
+
1298
+ # make sure experiment id does not exist
1299
+ if experiment_id is None:
1300
+ experiment_id = len(self.xticks) + 1
1301
+ if experiment_id in self.xticks:
1302
+ raise ValueError(f'An experiment with id {experiment_id} '
1303
+ 'was already created.')
1304
+
1305
+ self.timestamps[experiment_id] = datetime.fromtimestamp(
1306
+ time.time()).strftime('%Y-%m-%d %H:%M:%S')
1307
+ self.duration[experiment_id] = 0
1308
+ self.seeds[experiment_id] = key
1309
+ self.status[experiment_id] = 'N/A'
1310
+ self.progress[experiment_id] = 0
1311
+ self.warnings = []
1312
+ self.rddl[experiment_id] = planner_info['rddl']
1313
+ self.planner_info[experiment_id] = planner_info['string']
1314
+ self.checked[experiment_id] = False
1315
+
1316
+ self.xticks[experiment_id] = []
1317
+ self.train_return[experiment_id] = []
1318
+ self.test_return[experiment_id] = []
1319
+ self.return_dist_ticks[experiment_id] = []
1320
+ self.return_dist_last_progress[experiment_id] = 0
1321
+ self.return_dist[experiment_id] = []
1322
+ self.action_output[experiment_id] = None
1323
+ self.policy_params[experiment_id] = []
1324
+ self.policy_params_ticks[experiment_id] = []
1325
+ self.policy_params_last_progress[experiment_id] = 0
1326
+ self.policy_viz[experiment_id] = viz
1327
+
1328
+ decompiler = RDDLDecompiler()
1329
+ self.relaxed_exprs[experiment_id] = {}
1330
+ self.relaxed_exprs_values[experiment_id] = {}
1331
+ for info in planner_info['model_parameter_info'].values():
1332
+ expr = planner_info['trace_info'].lookup(info['id'])
1333
+ compiled_expr = decompiler.decompile_expr(expr)
1334
+ self.relaxed_exprs[experiment_id][info['id']] = compiled_expr
1335
+ self.relaxed_exprs_values[experiment_id][info['id']] = []
1336
+
1337
+ return experiment_id
1338
+
1339
+ @staticmethod
1340
+ def representative_trajectories(trajectories, k, max_iter=300):
1341
+ n = next(iter(trajectories.values())).shape[0]
1342
+ points = np.concatenate([
1343
+ np.reshape(1. * values, (n, -1))
1344
+ for values in trajectories.values()
1345
+ ], axis=1)
1346
+
1347
+ k = min(k, n)
1348
+ centroids = points[np.random.choice(n, k, replace=False)]
1349
+ for _ in range(max_iter):
1350
+ distances = np.linalg.norm(
1351
+ points[:, None, :] - centroids[None, :, :], axis=-1)
1352
+ cluster_assignment = np.argmin(distances, axis=1)
1353
+ new_centroids = np.stack([
1354
+ np.mean(points[cluster_assignment == i], axis=0)
1355
+ for i in range(k)
1356
+ ], axis=0)
1357
+ if np.allclose(new_centroids, centroids):
1358
+ break
1359
+ centroids = new_centroids
1360
+ return np.unique(np.argmin(distances, axis=0))
1361
+
1362
+ def update_experiment(self, experiment_id: str, callback: Dict[str, Any]) -> None:
1363
+ '''Pass new information and update the dashboard for a given experiment.'''
1364
+
1365
+ # data for return curves
1366
+ iteration = callback['iteration']
1367
+ self.xticks[experiment_id].append(iteration)
1368
+ self.train_return[experiment_id].append(callback['train_return'])
1369
+ self.test_return[experiment_id].append(callback['best_return'])
1370
+
1371
+ # data for return distributions
1372
+ progress = callback['progress']
1373
+ if progress - self.return_dist_last_progress[experiment_id] \
1374
+ >= PROGRESS_FOR_NEXT_RETURN_DIST:
1375
+ self.return_dist_ticks[experiment_id].append(iteration)
1376
+ self.return_dist[experiment_id].append(
1377
+ np.sum(np.asarray(callback['reward']), axis=1))
1378
+ self.return_dist_last_progress[experiment_id] = progress
1379
+
1380
+ # data for action heatmaps
1381
+ action_output = []
1382
+ rddl = self.rddl[experiment_id]
1383
+ for action in rddl.action_fluents:
1384
+ action_values = np.asarray(callback['fluents'][action])
1385
+ action_output.append(
1386
+ (action_values.reshape(action_values.shape[:2] + (-1,)),
1387
+ action,
1388
+ rddl.variable_groundings[action])
1389
+ )
1390
+ self.action_output[experiment_id] = action_output
1391
+
1392
+ # data for policy weight distributions
1393
+ if progress - self.policy_params_last_progress[experiment_id] \
1394
+ >= PROGRESS_FOR_NEXT_POLICY_DIST:
1395
+ self.policy_params_ticks[experiment_id].append(iteration)
1396
+ self.policy_params[experiment_id].append(callback['best_params'])
1397
+ self.policy_params_last_progress[experiment_id] = progress
1398
+
1399
+ # data for model relaxations
1400
+ model_params = callback['model_params']
1401
+ for (key, values) in model_params.items():
1402
+ expr_id = int(str(key).split('_')[0])
1403
+ self.relaxed_exprs_values[experiment_id][expr_id].append(values.item())
1404
+ self.train_reward_dist[experiment_id] = callback['train_log']['reward']
1405
+ self.test_reward_dist[experiment_id] = callback['reward']
1406
+ self.train_state_fluents[experiment_id] = {
1407
+ name: np.asarray(callback['train_log']['fluents'][name])
1408
+ for name in list(rddl.state_fluents) + list(rddl.observ_fluents)
1409
+ }
1410
+ self.test_state_fluents[experiment_id] = {
1411
+ name: np.asarray(callback['fluents'][name])
1412
+ for name in self.train_state_fluents[experiment_id]
1413
+ }
1414
+
1415
+ # update experiment table info
1416
+ self.status[experiment_id] = str(callback['status']).split('.')[1]
1417
+ self.duration[experiment_id] = callback["elapsed_time"]
1418
+ self.progress[experiment_id] = progress
1419
+ self.warnings = None
1420
+
1421
+ def update_tuning(self, optimizer: Any,
1422
+ bounds: Dict[str, Tuple[float, float]]) -> None:
1423
+ '''Updates the hyper-parameter tuning plots.'''
1424
+
1425
+ self.tuning_gp_heatmaps = []
1426
+ self.tuning_gp_update = False
1427
+ if not optimizer.res: return
1428
+
1429
+ self.tuning_gp_targets = optimizer.space.target.reshape((-1,))
1430
+ self.tuning_gp_predicted = \
1431
+ optimizer._gp.predict(optimizer.space.params).reshape((-1,))
1432
+ self.tuning_gp_params = {name: optimizer.space.params[:, i]
1433
+ for (i, name) in enumerate(optimizer.space.keys)}
1434
+
1435
+ for (i1, param1) in enumerate(optimizer.space.keys):
1436
+ self.tuning_gp_heatmaps.append([])
1437
+ for (i2, param2) in enumerate(optimizer.space.keys):
1438
+ if i2 > i1:
1439
+
1440
+ # Generate a grid for visualization
1441
+ p1_values = np.linspace(*bounds[param1], 100)
1442
+ p2_values = np.linspace(*bounds[param2], 100)
1443
+ P1, P2 = np.meshgrid(p1_values, p2_values)
1444
+
1445
+ # Predict the mean and deviation of the surrogate model
1446
+ fixed_params = max(
1447
+ optimizer.res,
1448
+ key=lambda x: x['target'])['params'].copy()
1449
+ fixed_params.pop(param1)
1450
+ fixed_params.pop(param2)
1451
+ param_grid = []
1452
+ for p1, p2 in zip(np.ravel(P1), np.ravel(P2)):
1453
+ params = {param1: p1, param2: p2}
1454
+ params.update(fixed_params)
1455
+ param_grid.append(
1456
+ [params[key] for key in optimizer.space.keys])
1457
+ param_grid = np.asarray(param_grid)
1458
+ mean, std = optimizer._gp.predict(param_grid, return_std=True)
1459
+ mean = mean.reshape(P1.shape)
1460
+ std = std.reshape(P1.shape)
1461
+ self.tuning_gp_heatmaps[-1].append(
1462
+ (param1, param2, p1_values, p2_values, mean, std))
1463
+ self.tuning_gp_update = True
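
As a side note on the helper added above: representative_trajectories runs a small k-means pass over the flattened state rollouts and returns the batch index of the rollout nearest each centroid, which is what the policy-visualization callback renders. A self-contained check with synthetic data (shapes are illustrative only):

    import numpy as np
    from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard

    # 32 rollouts, 20 decision epochs, two fluents with different shapes
    rng = np.random.default_rng(0)
    trajectories = {
        'pos': rng.normal(size=(32, 20, 3)),
        'vel': rng.normal(size=(32, 20)),
    }

    # flattens each rollout to a feature vector, clusters with k-means,
    # and returns at most k unique batch indices (one per cluster)
    idx = JaxPlannerDashboard.representative_trajectories(trajectories, k=3)
    print(idx)   # e.g. up to 3 unique indices in [0, 32)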