pyRDDLGym-jax 0.5__py3-none-any.whl → 1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +463 -592
- pyRDDLGym_jax/core/logic.py +784 -544
- pyRDDLGym_jax/core/planner.py +329 -463
- pyRDDLGym_jax/core/simulator.py +7 -5
- pyRDDLGym_jax/core/tuning.py +379 -568
- pyRDDLGym_jax/core/visualization.py +1463 -0
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +4 -5
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
- pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
- pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
- pyRDDLGym_jax/examples/run_plan.py +4 -1
- pyRDDLGym_jax/examples/run_tune.py +40 -27
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/METADATA +161 -104
- pyRDDLGym_jax-1.0.dist-info/RECORD +45 -0
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
- pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
- pyRDDLGym_jax-0.5.dist-info/RECORD +0 -44
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1463 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import os
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
import math
|
|
5
|
+
import numpy as np
|
|
6
|
+
import time
|
|
7
|
+
import threading
|
|
8
|
+
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
|
9
|
+
import warnings
|
|
10
|
+
import webbrowser
|
|
11
|
+
|
|
12
|
+
# prevent endless console prints
|
|
13
|
+
import logging
|
|
14
|
+
log = logging.getLogger('werkzeug')
|
|
15
|
+
log.setLevel(logging.ERROR)
|
|
16
|
+
|
|
17
|
+
import dash
|
|
18
|
+
from dash.dcc import Interval, Graph, Store
|
|
19
|
+
from dash.dependencies import Input, Output, State, ALL
|
|
20
|
+
from dash.html import Div, B, H4, P, Img, Hr
|
|
21
|
+
import dash_bootstrap_components as dbc
|
|
22
|
+
|
|
23
|
+
import plotly.colors as pc
|
|
24
|
+
import plotly.graph_objs as go
|
|
25
|
+
from plotly.subplots import make_subplots
|
|
26
|
+
|
|
27
|
+
from pyRDDLGym.core.debug.decompiler import RDDLDecompiler
|
|
28
|
+
|
|
29
|
+
if TYPE_CHECKING:
|
|
30
|
+
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner
|
|
31
|
+
|
|
32
|
+
POLICY_DIST_HEIGHT = 400
|
|
33
|
+
POLICY_DIST_PLOTS_PER_ROW = 6
|
|
34
|
+
ACTION_HEATMAP_HEIGHT = 400
|
|
35
|
+
PROGRESS_FOR_NEXT_RETURN_DIST = 2
|
|
36
|
+
PROGRESS_FOR_NEXT_POLICY_DIST = 10
|
|
37
|
+
REWARD_ERROR_DIST_SUBPLOTS = 20
|
|
38
|
+
MODEL_STATE_ERROR_HEIGHT = 300
|
|
39
|
+
POLICY_STATE_VIZ_MAX_HEIGHT = 800
|
|
40
|
+
GP_POSTERIOR_MAX_HEIGHT = 800
|
|
41
|
+
|
|
42
|
+
PLOT_AXES_FONT_SIZE = 11
|
|
43
|
+
EXPERIMENT_ENTRY_FONT_SIZE = 14
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class JaxPlannerDashboard:
|
|
47
|
+
'''A dashboard app for monitoring the jax planner progress.'''
|
|
48
|
+
|
|
49
|
+
def __init__(self, theme: str=dbc.themes.CERULEAN) -> None:
|
|
50
|
+
|
|
51
|
+
self.timestamps = {}
|
|
52
|
+
self.duration = {}
|
|
53
|
+
self.seeds = {}
|
|
54
|
+
self.status = {}
|
|
55
|
+
self.warnings = []
|
|
56
|
+
self.progress = {}
|
|
57
|
+
self.checked = {}
|
|
58
|
+
self.rddl = {}
|
|
59
|
+
self.planner_info = {}
|
|
60
|
+
|
|
61
|
+
self.xticks = {}
|
|
62
|
+
self.test_return = {}
|
|
63
|
+
self.train_return = {}
|
|
64
|
+
self.return_dist = {}
|
|
65
|
+
self.return_dist_ticks = {}
|
|
66
|
+
self.return_dist_last_progress = {}
|
|
67
|
+
self.action_output = {}
|
|
68
|
+
self.policy_params = {}
|
|
69
|
+
self.policy_params_ticks = {}
|
|
70
|
+
self.policy_params_last_progress = {}
|
|
71
|
+
self.policy_viz = {}
|
|
72
|
+
|
|
73
|
+
self.relaxed_exprs = {}
|
|
74
|
+
self.relaxed_exprs_values = {}
|
|
75
|
+
self.train_reward_dist = {}
|
|
76
|
+
self.test_reward_dist = {}
|
|
77
|
+
self.train_state_fluents = {}
|
|
78
|
+
self.test_state_fluents = {}
|
|
79
|
+
|
|
80
|
+
self.tuning_gp_heatmaps = None
|
|
81
|
+
self.tuning_gp_targets = None
|
|
82
|
+
self.tuning_gp_predicted = None
|
|
83
|
+
self.tuning_gp_params = None
|
|
84
|
+
self.tuning_gp_update = False
|
|
85
|
+
|
|
86
|
+
# ======================================================================
|
|
87
|
+
# CREATE PAGE LAYOUT
|
|
88
|
+
# ======================================================================
|
|
89
|
+
|
|
90
|
+
def create_experiment_table(active_page, page_size):
    '''Builds the rows of the experiment table for a single page.

    The first row returned is always the column header. It is followed by
    one row per experiment in self.checked whose position lies in the range
    [(active_page - 1) * page_size, active_page * page_size).
    '''
    borderless = "border-0 bg-transparent"

    def make_cell(content, width):
        # every entry of the table is a zero-padding, borderless card
        body = dbc.CardBody(content, style={"padding": "0"})
        return dbc.Col([dbc.Card(body, className=borderless)], width=width)

    first = (active_page - 1) * page_size
    last = first + page_size

    # header
    columns = [
        ('Display', 1), ('Experiment ID', 2), ('Seed', 1), ('Timestamp', 2),
        ('Duration', 1), ('Best Return', 2), ('Status', 2), ('Progress', 1)
    ]
    table = [
        dbc.Row([make_cell(B(title), width) for (title, width) in columns])
    ]

    # experiments
    for (position, exp_id) in enumerate(self.checked):
        if position >= last:
            break
        if position < first:
            continue
        percent = self.progress[exp_id]
        checkbox = dbc.Checkbox(
            id={'type': 'experiment-checkbox', 'index': exp_id},
            value=self.checked[exp_id]
        )
        cells = [
            make_cell([checkbox], 1),
            make_cell(exp_id, 2),
            make_cell(self.seeds[exp_id], 1),
            make_cell(self.timestamps[exp_id], 2),
            make_cell(f'{self.duration[exp_id]:.3f}s', 1),
            # most recent test return, or nan if none was recorded yet
            make_cell(f'{(self.test_return[exp_id] or [np.nan])[-1]:.3f}', 2),
            make_cell(self.status[exp_id], 2),
            make_cell([dbc.Progress(label=f"{percent}%", value=percent)], 1)
        ]
        table.append(dbc.Row(cells))
    return table
|
|
208
|
+
|
|
209
|
+
app = dash.Dash(__name__, external_stylesheets=[theme])
|
|
210
|
+
app.title = 'JaxPlan Dashboard'
|
|
211
|
+
|
|
212
|
+
app.layout = dbc.Container([
|
|
213
|
+
Store(id='refresh-interval', data=2000),
|
|
214
|
+
Store(id='experiment-num-per-page', data=10),
|
|
215
|
+
Store(id='model-params-dropdown-expr', data=''),
|
|
216
|
+
Store(id='model-errors-state-dropdown-selected', data=''),
|
|
217
|
+
Store(id='viz-skip-frequency', data=5),
|
|
218
|
+
Store(id='viz-num-trajectories', data=3),
|
|
219
|
+
Div(id='viewport-sizer', style={'display': 'none'}),
|
|
220
|
+
|
|
221
|
+
# navbar
|
|
222
|
+
dbc.Navbar(
|
|
223
|
+
dbc.Container([
|
|
224
|
+
# Img(src=LOGO_FILE, height="30px", style={'margin-right': '10px'}),
|
|
225
|
+
dbc.NavbarBrand(f"JaxPlan Dashboard"),
|
|
226
|
+
dbc.Nav([
|
|
227
|
+
dbc.NavItem(
|
|
228
|
+
dbc.NavLink(
|
|
229
|
+
"Docs",
|
|
230
|
+
href="https://pyrddlgym.readthedocs.io/en/latest/jax.html"
|
|
231
|
+
)
|
|
232
|
+
),
|
|
233
|
+
dbc.NavItem(
|
|
234
|
+
dbc.NavLink(
|
|
235
|
+
"GitHub",
|
|
236
|
+
href="https://github.com/pyrddlgym-project/pyRDDLGym-jax"
|
|
237
|
+
)
|
|
238
|
+
),
|
|
239
|
+
dbc.NavItem(
|
|
240
|
+
dbc.NavLink(
|
|
241
|
+
"Submit an Issue",
|
|
242
|
+
href="https://github.com/pyrddlgym-project/pyRDDLGym-jax/issues"
|
|
243
|
+
)
|
|
244
|
+
)
|
|
245
|
+
], navbar=True, className="me-auto"),
|
|
246
|
+
dbc.Nav([
|
|
247
|
+
dbc.DropdownMenu(
|
|
248
|
+
[dbc.DropdownMenuItem("500ms", id='05sec'),
|
|
249
|
+
dbc.DropdownMenuItem("1s", id='1sec'),
|
|
250
|
+
dbc.DropdownMenuItem("2s", id='2sec'),
|
|
251
|
+
dbc.DropdownMenuItem("5s", id='5sec'),
|
|
252
|
+
dbc.DropdownMenuItem("10s", id='10sec'),
|
|
253
|
+
dbc.DropdownMenuItem("30s", id='30sec'),
|
|
254
|
+
dbc.DropdownMenuItem("1m", id='1min'),
|
|
255
|
+
dbc.DropdownMenuItem("5m", id='5min'),
|
|
256
|
+
dbc.DropdownMenuItem("1d", id='1day')],
|
|
257
|
+
label="Refresh: 2s",
|
|
258
|
+
id='refresh-rate-dropdown',
|
|
259
|
+
nav=True
|
|
260
|
+
)
|
|
261
|
+
], navbar=True),
|
|
262
|
+
dbc.Nav([
|
|
263
|
+
dbc.DropdownMenu(
|
|
264
|
+
[dbc.DropdownMenuItem("5", id='5pp'),
|
|
265
|
+
dbc.DropdownMenuItem("10", id='10pp'),
|
|
266
|
+
dbc.DropdownMenuItem("25", id='25pp'),
|
|
267
|
+
dbc.DropdownMenuItem("50", id='50pp')],
|
|
268
|
+
label="Exp. Per Page: 10",
|
|
269
|
+
id='experiment-num-per-page-dropdown',
|
|
270
|
+
nav=True
|
|
271
|
+
)
|
|
272
|
+
], navbar=True)
|
|
273
|
+
], fluid=True)
|
|
274
|
+
),
|
|
275
|
+
|
|
276
|
+
# experiments
|
|
277
|
+
dbc.Row([
|
|
278
|
+
dbc.Col([
|
|
279
|
+
dbc.Card([
|
|
280
|
+
dbc.CardBody([
|
|
281
|
+
Div(create_experiment_table(0, 10), id='experiment-table',
|
|
282
|
+
style={'fontSize': f'{EXPERIMENT_ENTRY_FONT_SIZE}px'}),
|
|
283
|
+
dbc.Pagination(id='experiment-pagination',
|
|
284
|
+
active_page=1, max_value=1, size="sm")
|
|
285
|
+
], style={'padding': '10px'})
|
|
286
|
+
], className="border-0 bg-transparent")
|
|
287
|
+
])
|
|
288
|
+
]),
|
|
289
|
+
|
|
290
|
+
# empirical results tabs
|
|
291
|
+
dbc.Row([
|
|
292
|
+
dbc.Col([
|
|
293
|
+
dbc.Tabs([
|
|
294
|
+
|
|
295
|
+
# returns
|
|
296
|
+
dbc.Tab(dbc.Card(
|
|
297
|
+
dbc.CardBody([
|
|
298
|
+
dbc.Row([
|
|
299
|
+
dbc.Col(Graph(id='train-return-graph'), width=6),
|
|
300
|
+
dbc.Col(Graph(id='test-return-graph'), width=6),
|
|
301
|
+
]),
|
|
302
|
+
dbc.Row([
|
|
303
|
+
Graph(id='dist-return-graph')
|
|
304
|
+
])
|
|
305
|
+
]), className="border-0 bg-transparent"
|
|
306
|
+
), label="Performance", tab_id='tab-performance'
|
|
307
|
+
),
|
|
308
|
+
|
|
309
|
+
# policy
|
|
310
|
+
dbc.Tab(dbc.Card(
|
|
311
|
+
dbc.CardBody([
|
|
312
|
+
dbc.Row([
|
|
313
|
+
Graph(id='action-output'),
|
|
314
|
+
]),
|
|
315
|
+
dbc.Row([
|
|
316
|
+
dbc.Col([
|
|
317
|
+
Hr(className='my-4')
|
|
318
|
+
])
|
|
319
|
+
]),
|
|
320
|
+
dbc.Row([
|
|
321
|
+
Graph(id='policy-params'),
|
|
322
|
+
]),
|
|
323
|
+
dbc.Row([
|
|
324
|
+
dbc.Col([
|
|
325
|
+
Hr(className='my-4')
|
|
326
|
+
])
|
|
327
|
+
]),
|
|
328
|
+
dbc.Row([
|
|
329
|
+
dbc.Col([
|
|
330
|
+
dbc.Row([
|
|
331
|
+
dbc.Col([
|
|
332
|
+
dbc.DropdownMenu(
|
|
333
|
+
[dbc.DropdownMenuItem('Every Frame', id='viz-skip-1'),
|
|
334
|
+
dbc.DropdownMenuItem('Every 2 Frames', id='viz-skip-2'),
|
|
335
|
+
dbc.DropdownMenuItem('Every 3 Frames', id='viz-skip-3'),
|
|
336
|
+
dbc.DropdownMenuItem('Every 4 Frames', id='viz-skip-4'),
|
|
337
|
+
dbc.DropdownMenuItem('Every 5 Frames', id='viz-skip-5'),
|
|
338
|
+
dbc.DropdownMenuItem('Every 10 Frames', id='viz-skip-10')],
|
|
339
|
+
label="Render: Every 5 Frames",
|
|
340
|
+
id='viz-skip-dropdown'
|
|
341
|
+
),
|
|
342
|
+
], width='auto'),
|
|
343
|
+
dbc.Col([
|
|
344
|
+
dbc.DropdownMenu(
|
|
345
|
+
[dbc.DropdownMenuItem('1', id='viz-num-1'),
|
|
346
|
+
dbc.DropdownMenuItem('2', id='viz-num-2'),
|
|
347
|
+
dbc.DropdownMenuItem('3', id='viz-num-3'),
|
|
348
|
+
dbc.DropdownMenuItem('4', id='viz-num-4'),
|
|
349
|
+
dbc.DropdownMenuItem('5', id='viz-num-5')],
|
|
350
|
+
label="Max. Trajectories: 3",
|
|
351
|
+
id='viz-num-dropdown'
|
|
352
|
+
),
|
|
353
|
+
], width='auto'),
|
|
354
|
+
dbc.Col([
|
|
355
|
+
dbc.Button('Run Policy Visualization',
|
|
356
|
+
id='policy-viz-button'),
|
|
357
|
+
], width='auto')
|
|
358
|
+
]),
|
|
359
|
+
dbc.Row([
|
|
360
|
+
Graph(id='policy-viz')
|
|
361
|
+
])
|
|
362
|
+
])
|
|
363
|
+
]),
|
|
364
|
+
]), className="border-0 bg-transparent"
|
|
365
|
+
), label="Policy", tab_id='tab-policy'
|
|
366
|
+
),
|
|
367
|
+
|
|
368
|
+
# model
|
|
369
|
+
dbc.Tab(dbc.Card(
|
|
370
|
+
dbc.CardBody([
|
|
371
|
+
dbc.Row([
|
|
372
|
+
dbc.Col([
|
|
373
|
+
dbc.Card([
|
|
374
|
+
dbc.DropdownMenu(
|
|
375
|
+
[],
|
|
376
|
+
label="RDDL Expression",
|
|
377
|
+
id='model-params-dropdown'
|
|
378
|
+
),
|
|
379
|
+
Graph(id='model-params-graph')
|
|
380
|
+
], className="border-0 bg-transparent"
|
|
381
|
+
),
|
|
382
|
+
])
|
|
383
|
+
]),
|
|
384
|
+
dbc.Row([
|
|
385
|
+
dbc.Col([
|
|
386
|
+
Hr(className='my-4')
|
|
387
|
+
])
|
|
388
|
+
]),
|
|
389
|
+
dbc.Row([
|
|
390
|
+
dbc.Col([
|
|
391
|
+
dbc.Card(
|
|
392
|
+
dbc.CardBody(
|
|
393
|
+
Graph(id='model-errors-reward-graph')
|
|
394
|
+
),
|
|
395
|
+
className="border-0 bg-transparent"
|
|
396
|
+
),
|
|
397
|
+
])
|
|
398
|
+
]),
|
|
399
|
+
dbc.Row([
|
|
400
|
+
dbc.Col([
|
|
401
|
+
dbc.Card([
|
|
402
|
+
dbc.DropdownMenu(
|
|
403
|
+
[],
|
|
404
|
+
label="State-Fluent",
|
|
405
|
+
id='model-errors-state-dropdown'
|
|
406
|
+
),
|
|
407
|
+
Graph(id='model-errors-state-graph')
|
|
408
|
+
], className="border-0 bg-transparent"
|
|
409
|
+
)
|
|
410
|
+
])
|
|
411
|
+
]),
|
|
412
|
+
]), className="border-0 bg-transparent"
|
|
413
|
+
), label="Model", tab_id='tab-model'
|
|
414
|
+
),
|
|
415
|
+
|
|
416
|
+
# information
|
|
417
|
+
dbc.Tab(dbc.Card(
|
|
418
|
+
dbc.CardBody([
|
|
419
|
+
dbc.Row([
|
|
420
|
+
dbc.Alert(id="planner-info", color="light", dismissable=False,
|
|
421
|
+
style={"fontFamily": "Courier, monospace"})
|
|
422
|
+
]),
|
|
423
|
+
]), className="border-0 bg-transparent"
|
|
424
|
+
), label="Debug", tab_id='tab-debug'
|
|
425
|
+
),
|
|
426
|
+
|
|
427
|
+
# tuning
|
|
428
|
+
dbc.Tab(dbc.Card(
|
|
429
|
+
dbc.CardBody([
|
|
430
|
+
dbc.Row([
|
|
431
|
+
dbc.Col(Graph(id='tuning-target-graph'), width=6),
|
|
432
|
+
dbc.Col(Graph(id='tuning-scatter-graph'), width=6)
|
|
433
|
+
]),
|
|
434
|
+
dbc.Row([
|
|
435
|
+
dbc.Col([
|
|
436
|
+
Hr(className='my-4')
|
|
437
|
+
])
|
|
438
|
+
]),
|
|
439
|
+
dbc.Row([
|
|
440
|
+
dbc.Col(Graph(id='tuning-gp-mean-graph'))
|
|
441
|
+
]),
|
|
442
|
+
dbc.Row([
|
|
443
|
+
dbc.Col(Graph(id='tuning-gp-unc-graph'))
|
|
444
|
+
])
|
|
445
|
+
]), className="border-0 bg-transparent"
|
|
446
|
+
), label="Tuning", tab_id='tab-tuning'
|
|
447
|
+
),
|
|
448
|
+
], id='tabs-main')
|
|
449
|
+
], width=12)
|
|
450
|
+
]),
|
|
451
|
+
|
|
452
|
+
# refresh interval
|
|
453
|
+
Interval(
|
|
454
|
+
id='interval',
|
|
455
|
+
interval=2000,
|
|
456
|
+
n_intervals=0
|
|
457
|
+
),
|
|
458
|
+
Div(id='trigger-experiment-check', style={'display': 'none'})
|
|
459
|
+
], fluid=True, className="dbc")
|
|
460
|
+
|
|
461
|
+
# JavaScript to retrieve the viewport dimensions
|
|
462
|
+
app.clientside_callback(
|
|
463
|
+
"""
|
|
464
|
+
function(n_intervals) {
|
|
465
|
+
return {
|
|
466
|
+
'height': window.innerHeight,
|
|
467
|
+
'width': window.innerWidth
|
|
468
|
+
};
|
|
469
|
+
}
|
|
470
|
+
""",
|
|
471
|
+
Output('viewport-sizer', 'children'),
|
|
472
|
+
Input('interval', 'n_intervals')
|
|
473
|
+
)
|
|
474
|
+
|
|
475
|
+
# ======================================================================
|
|
476
|
+
# CREATE EVENTS
|
|
477
|
+
# ======================================================================
|
|
478
|
+
|
|
479
|
+
# modify refresh rate
|
|
480
|
+
@app.callback(
|
|
481
|
+
Output("refresh-interval", "data"),
|
|
482
|
+
[Input("05sec", "n_clicks"),
|
|
483
|
+
Input("1sec", "n_clicks"),
|
|
484
|
+
Input("2sec", "n_clicks"),
|
|
485
|
+
Input("5sec", "n_clicks"),
|
|
486
|
+
Input("10sec", "n_clicks"),
|
|
487
|
+
Input("30sec", "n_clicks"),
|
|
488
|
+
Input("1min", "n_clicks"),
|
|
489
|
+
Input("5min", "n_clicks"),
|
|
490
|
+
Input("1day", "n_clicks")],
|
|
491
|
+
[State('refresh-interval', 'data')]
|
|
492
|
+
)
|
|
493
|
+
def click_refresh_rate(n05, n1, n2, n5, n10, n30, n1m, n5m, nd, data):
    '''Translates the clicked refresh-rate menu item into milliseconds.

    The click counts are only used to trigger the callback. The stored
    value (data) is returned unchanged when nothing triggered the callback
    or when the triggering id is not one of the known menu items.
    '''
    triggered = dash.callback_context.triggered
    if not triggered:
        return data

    # prop_id has the form "<component id>.<property>"
    clicked = triggered[0]['prop_id'].split('.')[0]
    millis_by_item = {
        '05sec': 500,
        '1sec': 1000,
        '2sec': 2000,
        '5sec': 5000,
        '10sec': 10000,
        '30sec': 30000,
        '1min': 60000,
        '5min': 300000,
        '1day': 86400000
    }
    return millis_by_item.get(clicked, data)
|
|
517
|
+
|
|
518
|
+
@app.callback(
|
|
519
|
+
Output('interval', 'interval'),
|
|
520
|
+
[Input('refresh-interval', 'data')]
|
|
521
|
+
)
|
|
522
|
+
def update_refresh_rate(selected_interval):
    '''Returns the interval to apply, using 2000 when the stored one is falsy.'''
    return selected_interval or 2000
|
|
524
|
+
|
|
525
|
+
@app.callback(
|
|
526
|
+
Output('refresh-rate-dropdown', 'label'),
|
|
527
|
+
[Input('refresh-interval', 'data')]
|
|
528
|
+
)
|
|
529
|
+
def update_refresh_rate(selected_interval):
    '''Returns the refresh dropdown label for an interval in milliseconds.

    Covers every option offered by the refresh-rate menu, including the
    one-day option (86400000 ms), which previously fell through to the
    default and was displayed as "Refresh: 2s". Unrecognized values still
    produce the default label 'Refresh: 2s'.
    '''
    labels = {
        500: 'Refresh: 500ms',
        1000: 'Refresh: 1s',
        2000: 'Refresh: 2s',
        5000: 'Refresh: 5s',
        10000: 'Refresh: 10s',
        30000: 'Refresh: 30s',
        60000: 'Refresh: 1m',
        300000: 'Refresh: 5m',
        86400000: 'Refresh: 1d'
    }
    return labels.get(selected_interval, 'Refresh: 2s')
|
|
548
|
+
|
|
549
|
+
# update the experiments per page
|
|
550
|
+
@app.callback(
|
|
551
|
+
Output("experiment-num-per-page", "data"),
|
|
552
|
+
[Input("5pp", "n_clicks"),
|
|
553
|
+
Input("10pp", "n_clicks"),
|
|
554
|
+
Input("25pp", "n_clicks"),
|
|
555
|
+
Input("50pp", "n_clicks")],
|
|
556
|
+
[State('experiment-num-per-page', 'data')]
|
|
557
|
+
)
|
|
558
|
+
def click_experiments_per_page(n5, n10, n25, n50, data):
    '''Translates the clicked page-size menu item into a number of rows.

    The click counts are only used to trigger the callback. The stored
    value (data) is returned unchanged when nothing triggered the callback
    or when the triggering id is not one of the known menu items.
    '''
    triggered = dash.callback_context.triggered
    if not triggered:
        return data

    # prop_id has the form "<component id>.<property>"
    clicked = triggered[0]['prop_id'].split('.')[0]
    rows_by_item = {'5pp': 5, '10pp': 10, '25pp': 25, '50pp': 50}
    return rows_by_item.get(clicked, data)
|
|
572
|
+
|
|
573
|
+
@app.callback(
|
|
574
|
+
Output('experiment-num-per-page-dropdown', 'label'),
|
|
575
|
+
[Input('experiment-num-per-page', 'data')]
|
|
576
|
+
)
|
|
577
|
+
def update_experiments_per_page(selected_num):
    '''Returns the page-size dropdown label for the selected number of rows.'''
    return 'Exp. Per Page: {}'.format(selected_num)
|
|
579
|
+
|
|
580
|
+
# update the experiment table
|
|
581
|
+
@app.callback(
|
|
582
|
+
Output('experiment-table', 'children'),
|
|
583
|
+
[Input('interval', 'n_intervals'),
|
|
584
|
+
Input('experiment-pagination', 'active_page')],
|
|
585
|
+
State('experiment-num-per-page', 'data')
|
|
586
|
+
)
|
|
587
|
+
def update_experiment_table(n, active_page, npp):
    '''Rebuilds the experiment table rows for the active page.

    n is only used to trigger the callback; npp is the number of
    experiments shown per page.
    '''
    table_rows = create_experiment_table(active_page, npp)
    return table_rows
|
|
589
|
+
|
|
590
|
+
@app.callback(
|
|
591
|
+
Output('experiment-pagination', 'max_value'),
|
|
592
|
+
Input('interval', 'n_intervals'),
|
|
593
|
+
State('experiment-num-per-page', 'data')
|
|
594
|
+
)
|
|
595
|
+
def update_experiment_max_pages(n, npp):
    '''Returns how many pages are needed to list all experiments, npp per page.'''
    num_experiments = len(self.checked)
    # integer ceiling division
    return (num_experiments + npp - 1) // npp
|
|
597
|
+
|
|
598
|
+
@app.callback(
|
|
599
|
+
Output('trigger-experiment-check', 'children'),
|
|
600
|
+
Input({'type': 'experiment-checkbox', 'index': ALL}, 'value'),
|
|
601
|
+
State({'type': 'experiment-checkbox', 'index': ALL}, 'id')
|
|
602
|
+
)
|
|
603
|
+
def update_checked_experiment_status(checked, ids):
    '''Stores the state of every experiment checkbox in self.checked.

    checked holds the checkbox values and ids the matching component ids,
    whose 'index' entry is the key written in self.checked. The current
    time is returned so that the output value differs on every call.
    '''
    for (value, component_id) in zip(checked, ids):
        self.checked[component_id['index']] = value
    return time.time()
|
|
608
|
+
|
|
609
|
+
# update the return information
|
|
610
|
+
@app.callback(
|
|
611
|
+
Output('train-return-graph', 'figure'),
|
|
612
|
+
[Input('interval', 'n_intervals'),
|
|
613
|
+
Input('trigger-experiment-check', 'children'),
|
|
614
|
+
Input('tabs-main', 'active_tab')]
|
|
615
|
+
)
|
|
616
|
+
def update_train_return_graph(n, trigger, active_tab):
    '''Plots the train return curve of every checked experiment.

    Nothing is updated unless the performance tab is the active one; n and
    trigger are only used to trigger the callback.
    '''
    if active_tab != 'tab-performance':
        return dash.no_update

    # iterate over a snapshot of the checked experiments
    selected = [row for (row, checked) in self.checked.copy().items() if checked]
    fig = go.Figure()
    for row in selected:
        curve = go.Scatter(
            x=self.xticks[row], y=self.train_return[row],
            name=f'id={row}',
            mode='lines+markers',
            marker=dict(size=3), line=dict(width=2)
        )
        fig.add_trace(curve)
    fig.update_layout(
        title=dict(text="Train Return"),
        xaxis=dict(title=dict(text="Training Iteration")),
        yaxis=dict(title=dict(text="Cumulative Reward")),
        font=dict(size=PLOT_AXES_FONT_SIZE),
        legend=dict(bgcolor='rgba(0,0,0,0)'),
        template="plotly_white"
    )
    return fig
|
|
636
|
+
|
|
637
|
+
@app.callback(
|
|
638
|
+
Output('test-return-graph', 'figure'),
|
|
639
|
+
[Input('interval', 'n_intervals'),
|
|
640
|
+
Input('trigger-experiment-check', 'children'),
|
|
641
|
+
Input('tabs-main', 'active_tab')]
|
|
642
|
+
)
|
|
643
|
+
def update_test_return_graph(n, trigger, active_tab):
    '''Plots the test return curve of every checked experiment.

    Nothing is updated unless the performance tab is the active one; n and
    trigger are only used to trigger the callback.
    '''
    if active_tab != 'tab-performance':
        return dash.no_update

    # iterate over a snapshot of the checked experiments
    selected = [row for (row, checked) in self.checked.copy().items() if checked]
    fig = go.Figure()
    for row in selected:
        curve = go.Scatter(
            x=self.xticks[row], y=self.test_return[row],
            name=f'id={row}',
            mode='lines+markers',
            marker=dict(size=3), line=dict(width=2)
        )
        fig.add_trace(curve)
    fig.update_layout(
        title=dict(text="Test Return"),
        xaxis=dict(title=dict(text="Training Iteration")),
        yaxis=dict(title=dict(text="Cumulative Reward")),
        font=dict(size=PLOT_AXES_FONT_SIZE),
        legend=dict(bgcolor='rgba(0,0,0,0)'),
        template="plotly_white"
    )
    return fig
|
|
663
|
+
|
|
664
|
+
@app.callback(
|
|
665
|
+
Output('dist-return-graph', 'figure'),
|
|
666
|
+
[Input('interval', 'n_intervals'),
|
|
667
|
+
Input('trigger-experiment-check', 'children'),
|
|
668
|
+
Input('tabs-main', 'active_tab')]
|
|
669
|
+
)
|
|
670
|
+
def update_dist_return_graph(n, trigger, active_tab):
    '''Plots the stored return distributions of the first checked experiment.

    One half-violin is drawn per stored distribution, labelled with its
    tick and shaded along the 'Blues' colorscale. Nothing is updated unless
    the performance tab is the active one; n and trigger are only used to
    trigger the callback.
    '''
    if active_tab != 'tab-performance':
        return dash.no_update

    fig = go.Figure()
    fig.update_layout(template='plotly_white')
    for (row, checked) in self.checked.copy().items():
        if not checked:
            continue
        dists = self.return_dist[row]
        dist_ticks = self.return_dist_ticks[row]
        shades = pc.sample_colorscale(
            pc.get_colorscale('Blues'),
            np.linspace(0.1, 1, len(dists))
        )
        for (tick, dist, shade) in zip(dist_ticks, dists, shades):
            fig.add_trace(go.Violin(y=dist, line_color=shade, name=f'{tick}'))
        fig.update_traces(
            orientation='v', side='positive', width=3, points=False)
        fig.update_layout(
            title=dict(text="Distribution of Return"),
            xaxis=dict(title=dict(text="Training Iteration")),
            yaxis=dict(title=dict(text="Cumulative Reward")),
            font=dict(size=PLOT_AXES_FONT_SIZE),
            showlegend=False,
            yaxis_showgrid=False, yaxis_zeroline=False,
            template="plotly_white"
        )
        # only the first checked experiment is displayed
        break
    return fig
|
|
699
|
+
|
|
700
|
+
# update the action heatmap
|
|
701
|
+
@app.callback(
|
|
702
|
+
Output('action-output', 'figure'),
|
|
703
|
+
[Input('interval', 'n_intervals'),
|
|
704
|
+
Input('trigger-experiment-check', 'children'),
|
|
705
|
+
Input('tabs-main', 'active_tab')]
|
|
706
|
+
)
|
|
707
|
+
def update_action_heatmap(n, trigger, active_tab):
    '''Draws heatmaps of the action outputs of the first eligible experiment.

    An experiment is eligible when it is checked and has action output
    stored. Each stored action gets one row with two heatmaps: the mean
    (left) and the standard deviation (right) of its output taken over the
    first axis. Nothing is updated unless the policy tab is the active one;
    n and trigger are only used to trigger the callback.
    '''
    if active_tab != 'tab-policy':
        return dash.no_update

    fig = go.Figure()
    fig.update_layout(template='plotly_white')
    for (row, checked) in self.checked.copy().items():
        if not checked:
            continue
        outputs = self.action_output[row]
        if outputs is None:
            continue

        num_plots = len(outputs)
        titles = [
            title
            for (_, act, _) in outputs
            for title in (f'Values of Action-Fluents {act}',
                          f'Std. Dev. of Action-Fluents {act}')
        ]
        fig = make_subplots(
            rows=num_plots, cols=2,
            shared_xaxes=True, horizontal_spacing=0.15,
            subplot_titles=titles
        )
        bar_len = 0.8 / num_plots
        for (i, (values, _, _)) in enumerate(outputs):
            # multiplying by 1. promotes the values to floating point
            means = np.mean(1. * values, axis=0).T
            stds = np.std(1. * values, axis=0).T
            # center each colorbar vertically on its own subplot row
            bar_y = 1 - (i + 0.5) / num_plots
            fig.add_trace(go.Heatmap(
                z=means,
                x=np.arange(means.shape[1]),
                y=np.arange(means.shape[0]),
                colorscale='Blues', colorbar_x=0.45,
                colorbar_len=bar_len,
                colorbar_y=bar_y
            ), row=i + 1, col=1)
            fig.add_trace(go.Heatmap(
                z=stds,
                x=np.arange(stds.shape[1]),
                y=np.arange(stds.shape[0]),
                colorscale='Reds', colorbar_len=bar_len,
                colorbar_y=bar_y
            ), row=i + 1, col=2)
        fig.update_layout(
            title="Values of Action-Fluents",
            xaxis=dict(title=dict(text="Decision Epoch")),
            font=dict(size=PLOT_AXES_FONT_SIZE),
            height=ACTION_HEATMAP_HEIGHT * num_plots,
            showlegend=False,
            template="plotly_white"
        )
        # only the first eligible experiment is displayed
        break
    return fig
|
|
752
|
+
|
|
753
|
+
# update the weight histograms
|
|
754
|
+
@app.callback(
|
|
755
|
+
Output('policy-params', 'figure'),
|
|
756
|
+
[Input('interval', 'n_intervals'),
|
|
757
|
+
Input('trigger-experiment-check', 'children'),
|
|
758
|
+
Input('tabs-main', 'active_tab')]
|
|
759
|
+
)
|
|
760
|
+
def update_policy_params(n, trigger, active_tab):
    '''Draws the distribution of policy weights of the first eligible experiment.

    An experiment is eligible when it is checked and has at least one
    policy parameter snapshot stored. Every weight tensor found in the
    dict-valued layers of the first snapshot gets its own subplot, holding
    one half-violin per snapshot, shaded along the 'Blues' colorscale.

    The subplot grid is sized from the number of weight tensors. It used to
    be sized from the number of snapshots, which dropped tensors when there
    were few snapshots and added empty rows when there were many.

    Nothing is updated unless the policy tab is the active one; n and
    trigger are only used to trigger the callback.
    '''
    if active_tab != 'tab-policy':
        return dash.no_update

    fig = go.Figure()
    fig.update_layout(template='plotly_white')
    for (row, checked) in self.checked.copy().items():
        policy_params = self.policy_params[row]
        policy_params_ticks = self.policy_params_ticks[row]
        if not (checked and policy_params is not None and policy_params):
            continue

        # one subplot per weight tensor of the first snapshot
        titles = [
            f'{layer_name}/{weight_name}'
            for (layer_name, layer_params) in policy_params[0].items()
            if isinstance(layer_params, dict)
            for weight_name in layer_params
        ]

        # make_subplots needs at least one row even without any titles
        n_rows = max(1, math.ceil(len(titles) / POLICY_DIST_PLOTS_PER_ROW))
        fig = make_subplots(
            rows=n_rows, cols=POLICY_DIST_PLOTS_PER_ROW,
            shared_xaxes=True,
            subplot_titles=titles
        )
        colors = pc.sample_colorscale(
            pc.get_colorscale('Blues'),
            np.linspace(0.1, 1, len(policy_params))
        )[::-1]

        # snapshots are added from the most recent to the oldest
        for (it, (tick, policy_params_t)) in enumerate(
                zip(policy_params_ticks[::-1], policy_params[::-1])):
            position = 0
            for layer_params in policy_params_t.values():
                if not isinstance(layer_params, dict):
                    continue
                for weight_values in layer_params.values():
                    (r, c) = divmod(position, POLICY_DIST_PLOTS_PER_ROW)
                    position += 1
                    # skip weights beyond the grid built from the first snapshot
                    if r < n_rows:
                        fig.add_trace(go.Violin(
                            x=np.ravel(weight_values),
                            line_color=colors[it], name=f'{tick}'
                        ), row=r + 1, col=c + 1)
        fig.update_traces(
            orientation='h', side='positive', width=3, points=False)
        fig.update_layout(
            title="Distribution of Network Weight Parameters",
            font=dict(size=PLOT_AXES_FONT_SIZE),
            showlegend=False,
            xaxis_showgrid=False, xaxis_zeroline=False,
            height=POLICY_DIST_HEIGHT * n_rows,
            template="plotly_white"
        )
        # only the first eligible experiment is displayed
        break
    return fig
|
|
812
|
+
|
|
813
|
+
# modify viz skip rate
|
|
814
|
+
@app.callback(
    Output("viz-skip-frequency", "data"),
    [Input("viz-skip-1", "n_clicks"),
     Input("viz-skip-2", "n_clicks"),
     Input("viz-skip-3", "n_clicks"),
     Input("viz-skip-4", "n_clicks"),
     Input("viz-skip-5", "n_clicks"),
     Input("viz-skip-10", "n_clicks")],
    [State('viz-skip-frequency', 'data')]
)
def click_viz_skip_rate(v1, v2, v3, v4, v5, v10, data):
    '''Stores the frame skip rate matching the menu item that was clicked.

    The click counts (v1 ... v10) are not read; only the id of the component
    that triggered the callback matters. The currently stored value is
    returned unchanged when nothing triggered the callback or when the
    trigger is not one of the known menu items.'''
    triggered = dash.callback_context.triggered
    if not triggered:
        return data

    # prop_id has the form "<component id>.<property>"
    source_id = triggered[0]['prop_id'].split('.')[0]
    rate_by_button = {f'viz-skip-{rate}': rate for rate in (1, 2, 3, 4, 5, 10)}
    return rate_by_button.get(source_id, data)
|
|
842
|
+
|
|
843
|
+
@app.callback(
    Output('viz-skip-dropdown', 'label'),
    [Input('viz-skip-frequency', 'data')]
)
def update_viz_skip_dropdown_text(viz_skip):
    '''Returns the dropdown label describing the stored frame skip rate.'''
    return ('Render: Every Frame' if viz_skip == 1
            else f'Render: Every {viz_skip} Frames')
|
|
852
|
+
|
|
853
|
+
# modify viz count
|
|
854
|
+
@app.callback(
    Output("viz-num-trajectories", "data"),
    [Input("viz-num-1", "n_clicks"),
     Input("viz-num-2", "n_clicks"),
     Input("viz-num-3", "n_clicks"),
     Input("viz-num-4", "n_clicks"),
     Input("viz-num-5", "n_clicks")],
    [State('viz-num-trajectories', 'data')]
)
def click_viz_num_render(v1, v2, v3, v4, v5, data):
    '''Stores the trajectory count matching the menu item that was clicked.

    The click counts (v1 ... v5) are not read; only the id of the component
    that triggered the callback matters. The currently stored value is
    returned unchanged when nothing triggered the callback or when the
    trigger is not one of the known menu items.'''
    triggered = dash.callback_context.triggered
    if not triggered:
        return data

    # prop_id has the form "<component id>.<property>"
    source_id = triggered[0]['prop_id'].split('.')[0]
    count_by_button = {f'viz-num-{count}': count for count in range(1, 6)}
    return count_by_button.get(source_id, data)
|
|
879
|
+
|
|
880
|
+
@app.callback(
    Output('viz-num-dropdown', 'label'),
    [Input('viz-num-trajectories', 'data')]
)
def update_viz_num_dropdown_text(viz_num):
    '''Returns the dropdown label showing the stored trajectory count.'''
    label = f'Max. Trajectories: {viz_num}'
    return label
|
|
886
|
+
|
|
887
|
+
# update the policy viz
|
|
888
|
+
@app.callback(
    Output('policy-viz', 'figure'),
    Input("policy-viz-button", "n_clicks"),
    [State('viewport-sizer', 'children'),
     State("viz-skip-frequency", "data"),
     State("viz-num-trajectories", "data")]
)
def update_policy_viz(n_clicks, viewport_size, skip_freq, viz_num):
    '''Renders representative trajectories of the first checked experiment
    that has both a visualizer and recorded state fluents.

    For each representative trajectory, every skip_freq-th time step is
    rendered and the rendered frames are averaged into a single image; the
    images are placed side by side in one row of subplots. Returns
    dash.no_update when the viewport size is unknown, the button was not
    clicked, or no experiment qualifies.'''
    if not viewport_size: return dash.no_update
    if not n_clicks: return dash.no_update

    for (row, checked) in self.checked.copy().items():
        viz = self.policy_viz[row]
        if not checked or viz is None:
            continue

        # an experiment has no entry here until its first update arrives, so
        # indexing unconditionally would raise KeyError for a checked
        # experiment that has not reported yet
        if row not in self.train_state_fluents:
            continue
        states = self.train_state_fluents[row]
        if not states:
            continue

        lookahead = next(iter(states.values())).shape[1]
        batch_idx = self.representative_trajectories(states, k=viz_num)
        policy_viz_frames = []
        for idx in batch_idx:
            avg_image = 0.
            num_image = 0

            # the visualizer is re-initialized before each trajectory
            viz.__init__(self.rddl[row])
            for t in range(0, lookahead, skip_freq):
                state_t = {name: values[idx, t]
                           for (name, values) in states.items()}
                state_t = self.rddl[row].ground_vars_with_values(state_t)
                avg_image += np.asarray(viz.render(state_t))
                num_image += 1
            avg_image /= num_image
            policy_viz_frames.append(avg_image)

        subplot_width = min(
            viewport_size['width'] // len(policy_viz_frames),
            POLICY_STATE_VIZ_MAX_HEIGHT)
        fig = make_subplots(
            rows=1, cols=len(policy_viz_frames)
        )
        for (col, frame) in enumerate(policy_viz_frames):
            fig.add_trace(go.Image(z=frame, hoverinfo='skip'),
                          row=1, col=1 + col)
        fig.update_layout(
            title="Representative Trajectories",
            font=dict(size=PLOT_AXES_FONT_SIZE),
            xaxis=dict(showticklabels=False),
            yaxis=dict(showticklabels=False),
            width=subplot_width * len(policy_viz_frames),
            height=subplot_width * 1,
            showlegend=False,
            template="plotly_white"
        )
        return fig
    return dash.no_update
|
|
940
|
+
|
|
941
|
+
# update the model parameter information
|
|
942
|
+
@app.callback(
    Output('model-params-dropdown', 'children'),
    [Input('trigger-experiment-check', 'children'),
     Input('tabs-main', 'active_tab')]
)
def update_model_params_dropdown_create(trigger, active_tab):
    '''Builds one dropdown entry per relaxed expression of the first checked
    experiment; returns an empty list when no experiment is checked.

    Each entry shows the expression id in bold followed by the expression
    text with newlines replaced by spaces and cut to 120 characters.'''
    if active_tab != 'tab-model':
        return dash.no_update

    for (row, checked) in self.checked.copy().items():
        if not checked:
            continue
        return [
            dbc.DropdownMenuItem(
                [B(f'{expr_id}: '), expr.replace('\n', ' ')[:120]],
                id={'type': 'expr-dropdown-item', 'index': expr_id})
            for (expr_id, expr) in self.relaxed_exprs[row].items()
        ]
    return []
|
|
960
|
+
|
|
961
|
+
@app.callback(
    Output('model-params-dropdown-expr', 'data'),
    Input({'type': 'expr-dropdown-item', 'index': ALL}, 'n_clicks')
)
def update_model_params_dropdown_select(n_clicks):
    '''Stores the index of the expression dropdown entry that was clicked.

    Returns dash.no_update when nothing triggered the callback, or when the
    first click count that is not None is falsy (or all of them are None).'''
    triggered = dash.callback_context.triggered
    if not triggered:
        return dash.no_update

    first_count = next(
        (count for count in n_clicks if count is not None), False)
    if not first_count:
        return dash.no_update

    # the part of prop_id before ".n_clicks" is the component id dictionary
    # serialized as text; parse it back to read its 'index' entry
    component_id = triggered[0]['prop_id'].split('.n_clicks')[0]
    return ast.literal_eval(component_id)['index']
|
|
973
|
+
|
|
974
|
+
@app.callback(
    Output('model-params-graph', 'figure'),
    [Input('interval', 'n_intervals'),
     Input('tabs-main', 'active_tab')],
    [State('model-params-dropdown-expr', 'data')]
)
def update_model_params_graph(n, active_tab, expr_id):
    '''Plots the recorded model parameter values of the selected expression
    against the training iterations of the first checked experiment.

    Returns an empty figure when no expression is selected (expr_id == '')
    or no experiment is checked.'''
    if active_tab != 'tab-model':
        return dash.no_update

    fig = go.Figure()
    fig.update_layout(template='plotly_white')
    if expr_id == '':
        return fig

    for (row, checked) in self.checked.copy().items():
        if not checked:
            continue
        curve = go.Scatter(
            x=self.xticks[row],
            y=self.relaxed_exprs_values[row][expr_id],
            mode='lines+markers',
            marker={'size': 3},
            line={'width': 2}
        )
        fig.add_trace(curve)
        fig.update_layout(
            title={'text': f"Model Parameters for Expression {expr_id}"},
            xaxis={'title': {'text': "Training Iteration"}},
            yaxis={'title': {'text': "Parameter Value"}},
            font={'size': PLOT_AXES_FONT_SIZE},
            legend={'bgcolor': 'rgba(0,0,0,0)'},
            template="plotly_white"
        )
        break
    return fig
|
|
1003
|
+
|
|
1004
|
+
# update the model errors information for reward
|
|
1005
|
+
@app.callback(
    Output('model-errors-reward-graph', 'figure'),
    [Input('interval', 'n_intervals'),
     Input('trigger-experiment-check', 'children'),
     Input('tabs-main', 'active_tab')]
)
def update_model_error_reward_graph(n, trigger, active_tab):
    '''Draws, per sampled decision epoch, a split violin comparing the reward
    recorded in train_reward_dist (red, left half) with the reward recorded
    in test_reward_dist (blue, right half) for the first checked experiment
    that has reward data.

    Returns an empty figure when no such experiment exists.'''
    if active_tab != 'tab-model':
        return dash.no_update

    fig = go.Figure()
    fig.update_layout(template='plotly_white')
    for (row, checked) in self.checked.copy().items():
        if not (checked and row in self.train_reward_dist):
            continue

        # sub-sample the epochs when there are more than the subplot budget
        num_epochs = self.train_reward_dist[row].shape[1]
        if num_epochs > REWARD_ERROR_DIST_SUBPLOTS:
            step = num_epochs // REWARD_ERROR_DIST_SUBPLOTS
        else:
            step = 1

        for epoch in range(0, num_epochs, step):
            train_half = go.Violin(
                y=self.train_reward_dist[row][:, epoch], x0=epoch,
                side='negative', line_color='red',
                name=f'Train Epoch {epoch + 1}'
            )
            fig.add_trace(train_half)
            test_half = go.Violin(
                y=self.test_reward_dist[row][:, epoch], x0=epoch,
                side='positive', line_color='blue',
                name=f'Test Epoch {epoch + 1}'
            )
            fig.add_trace(test_half)

        fig.update_traces(meanline_visible=True)
        fig.update_layout(
            title={'text': "Distribution of Reward in Relaxed Model vs True Model"},
            xaxis={'title': {'text': "Decision Epoch"}},
            yaxis={'title': {'text': "Reward"}},
            font={'size': PLOT_AXES_FONT_SIZE},
            violingap=0, violinmode='overlay', showlegend=False,
            legend={'bgcolor': 'rgba(0,0,0,0)'},
            template="plotly_white"
        )
        break
    return fig
|
|
1045
|
+
|
|
1046
|
+
# update the model errors information for state
|
|
1047
|
+
@app.callback(
    Output('model-errors-state-dropdown', 'children'),
    [Input('trigger-experiment-check', 'children'),
     Input('tabs-main', 'active_tab')]
)
def update_model_errors_state_dropdown_create(trigger, active_tab):
    '''Builds one dropdown entry per recorded state fluent of the first
    checked experiment that has state-fluent data.

    Returns an empty list when no checked experiment has data yet.'''
    if active_tab != 'tab-model': return dash.no_update

    for (row, checked) in self.checked.copy().items():
        # an experiment has no entry in train_state_fluents until its first
        # update arrives; use the same guard as the state graph callback so a
        # freshly checked experiment does not raise KeyError
        if checked and row in self.train_state_fluents:
            return [
                dbc.DropdownMenuItem(
                    [name],
                    id={'type': 'state-fluent-dropdown-item', 'index': name}
                )
                for name in self.train_state_fluents[row]
            ]
    return []
|
|
1065
|
+
|
|
1066
|
+
@app.callback(
    Output('model-errors-state-dropdown-selected', 'data'),
    Input({'type': 'state-fluent-dropdown-item', 'index': ALL}, 'n_clicks')
)
def update_model_errors_state_dropdown_select(n_clicks):
    '''Stores the index of the state-fluent dropdown entry that was clicked.

    Returns dash.no_update when nothing triggered the callback, or when the
    first click count that is not None is falsy (or all of them are None).'''
    triggered = dash.callback_context.triggered
    if not triggered:
        return dash.no_update

    first_count = next(
        (count for count in n_clicks if count is not None), False)
    if not first_count:
        return dash.no_update

    # the part of prop_id before ".n_clicks" is the component id dictionary
    # serialized as text; parse it back to read its 'index' entry
    component_id = triggered[0]['prop_id'].split('.n_clicks')[0]
    return ast.literal_eval(component_id)['index']
|
|
1078
|
+
|
|
1079
|
+
@app.callback(
    Output('model-errors-state-graph', 'figure'),
    [Input('interval', 'n_intervals'),
     Input('trigger-experiment-check', 'children'),
     Input('tabs-main', 'active_tab')],
    [State('model-errors-state-dropdown-selected', 'data')]
)
def update_model_errors_state_graph(n, trigger, active_tab, state):
    '''Draws split violins comparing the selected state fluent as recorded in
    train_state_fluents (red, left half) and test_state_fluents (blue, right
    half), one subplot row per flattened component of the fluent.

    Uses the first checked experiment that has state-fluent data. Returns an
    empty figure when no fluent is selected or no such experiment exists.'''
    if active_tab != 'tab-model':
        return dash.no_update

    fig = go.Figure()
    fig.update_layout(template='plotly_white')
    if not state:
        return fig

    for (row, checked) in self.checked.copy().items():
        if not (checked and row in self.train_state_fluents):
            continue

        # keep the first two axes and merge all remaining axes into one
        train_raw = self.train_state_fluents[row][state]
        test_raw = self.test_state_fluents[row][state]
        train_values = 1 * train_raw.reshape(train_raw.shape[:2] + (-1,))
        test_values = 1 * test_raw.reshape(test_raw.shape[:2] + (-1,))
        num_epochs, num_states = train_values.shape[1:]

        # sub-sample the epochs when there are more than the subplot budget
        if num_epochs > REWARD_ERROR_DIST_SUBPLOTS:
            step = num_epochs // REWARD_ERROR_DIST_SUBPLOTS
        else:
            step = 1

        fig = make_subplots(
            rows=num_states, cols=1, shared_xaxes=True,
            subplot_titles=self.rddl[row].variable_groundings[state]
        )
        for istate in range(num_states):
            subplot_row = istate + 1
            for epoch in range(0, num_epochs, step):
                train_half = go.Violin(
                    y=train_values[:, epoch, istate], x0=epoch,
                    side='negative', line_color='red',
                    name=f'Train Epoch {epoch + 1}'
                )
                fig.add_trace(train_half, row=subplot_row, col=1)
                test_half = go.Violin(
                    y=test_values[:, epoch, istate], x0=epoch,
                    side='positive', line_color='blue',
                    name=f'Test Epoch {epoch + 1}'
                )
                fig.add_trace(test_half, row=subplot_row, col=1)

        fig.update_traces(meanline_visible=True)
        fig.update_layout(
            title={'text': (f"Distribution of State-Fluent {state} "
                            f"in Relaxed Model vs True Model")},
            xaxis={'title': {'text': "Decision Epoch"}},
            yaxis={'title': {'text': "State-Fluent Value"}},
            font={'size': PLOT_AXES_FONT_SIZE},
            height=MODEL_STATE_ERROR_HEIGHT * num_states,
            violingap=0, violinmode='overlay', showlegend=False,
            legend={'bgcolor': 'rgba(0,0,0,0)'},
            template="plotly_white"
        )
        break
    return fig
|
|
1131
|
+
|
|
1132
|
+
# update the run information
|
|
1133
|
+
@app.callback(
    Output('planner-info', 'children'),
    [Input('interval', 'n_intervals'),
     Input('trigger-experiment-check', 'children'),
     Input('tabs-main', 'active_tab')]
)
def update_planner_info(n, trigger, active_tab):
    '''Returns a heading and the stored planner description for the first
    checked experiment, or an empty list when none is checked.'''
    if active_tab != 'tab-debug':
        return dash.no_update

    for (row, checked) in self.checked.copy().items():
        if not checked:
            continue
        heading = H4(f'Hyper-Parameters [id={row}]', className="alert-heading")
        details = P(self.planner_info[row], style={"whiteSpace": "pre-wrap"})
        return [heading, details]
    return []
|
|
1150
|
+
|
|
1151
|
+
# update the tuning result
|
|
1152
|
+
@app.callback(
    [Output('tuning-target-graph', 'figure'),
     Output('tuning-scatter-graph', 'figure'),
     Output('tuning-gp-mean-graph', 'figure'),
     Output('tuning-gp-unc-graph', 'figure')],
    [Input('interval', 'n_intervals'),
     Input('tabs-main', 'active_tab'),
     Input('viewport-sizer', 'children')]
)
def update_tuning_gp_graph(n, active_tab, viewport_size):
    '''Rebuilds the four hyper-parameter tuning figures from the data stored
    by update_tuning: the target value per trial, predicted versus actual
    target values, and a grid of heatmaps for the posterior mean and the
    posterior standard deviation.

    Does nothing (dash.no_update) unless the tuning_gp_update flag is set and
    the viewport size is known. The flag is cleared before returning, so the
    figures are rebuilt once per call to update_tuning. The n and active_tab
    arguments are not read.'''
    if not self.tuning_gp_update: return dash.no_update
    if not viewport_size: return dash.no_update

    # tuning target trend
    fig1 = go.Figure()
    fig1.add_trace(go.Scatter(
        x=np.arange(len(self.tuning_gp_targets)), y=self.tuning_gp_targets,
        mode='lines+markers',
        marker=dict(size=3), line=dict(width=2)
    ))
    fig1.update_layout(
        title=dict(text="Target Values of Trial Points"),
        xaxis=dict(title=dict(text="Trial Point")),
        yaxis=dict(title=dict(text="Target Value")),
        font=dict(size=PLOT_AXES_FONT_SIZE),
        legend=dict(bgcolor='rgba(0,0,0,0)'),
        template="plotly_white"
    )

    # tuning scatter actual and predicted
    fig2 = go.Figure()
    fig2.add_trace(go.Scatter(
        x=self.tuning_gp_targets, y=self.tuning_gp_predicted,
        mode='markers', marker=dict(size=6)
    ))
    # dotted diagonal from (min, min) to (max, max) of the actual targets:
    # points on this line are trials whose prediction equals the actual value
    fig2.add_shape(
        type="line",
        x0=np.min(self.tuning_gp_targets),
        y0=np.min(self.tuning_gp_targets),
        x1=np.max(self.tuning_gp_targets),
        y1=np.max(self.tuning_gp_targets),
        line=dict(dash="dot", color='gray')
    )
    fig2.update_layout(
        title=dict(text="Gaussian Process Goodness-of-Fit Plot"),
        xaxis=dict(title=dict(text="Actual Target Value")),
        yaxis=dict(title=dict(text="Predicted Target Value")),
        font=dict(size=PLOT_AXES_FONT_SIZE),
        legend=dict(bgcolor='rgba(0,0,0,0)'),
        template="plotly_white"
    )

    # tuning posterior plot
    # tuning_gp_heatmaps is a list of columns, each a list of per-pair entries;
    # the number of subplot rows is taken from the first column
    num_cols = len(self.tuning_gp_heatmaps)
    num_rows = len(self.tuning_gp_heatmaps[0])
    fig3 = make_subplots(rows=num_rows, cols=num_cols)
    fig4 = make_subplots(rows=num_rows, cols=num_cols)
    for col, data_col in enumerate(self.tuning_gp_heatmaps):
        for row, data in enumerate(data_col):
            # entry layout: (name 1, name 2, axis values 1, axis values 2,
            # posterior mean grid, posterior std grid)
            p1, p2, p1v, p2v, mean, std = data
            fig3.add_trace(go.Heatmap(
                z=mean, x=p1v, y=p2v, colorscale='Blues', showscale=False
            ), row=row + 1, col=col + 1)
            fig4.add_trace(go.Heatmap(
                z=std, x=p1v, y=p2v, colorscale='Reds', showscale=False
            ), row=row + 1, col=col + 1)
            # overlay the evaluated trial points on both heatmaps
            fig3.add_trace(go.Scatter(
                x=self.tuning_gp_params[p1],
                y=self.tuning_gp_params[p2],
                opacity=1,
                mode='markers',
                marker=dict(size=5, color='green', symbol='x')
            ), row=row + 1, col=col + 1)
            fig4.add_trace(go.Scatter(
                x=self.tuning_gp_params[p1],
                y=self.tuning_gp_params[p2],
                opacity=1,
                mode='markers',
                marker=dict(size=5, color='green', symbol='x')
            ), row=row + 1, col=col + 1)
            fig3.update_xaxes(title_text=p1, row=row + 1, col=col + 1)
            fig4.update_xaxes(title_text=p1, row=row + 1, col=col + 1)
            fig3.update_yaxes(title_text=p2, row=row + 1, col=col + 1)
            fig4.update_yaxes(title_text=p2, row=row + 1, col=col + 1)
    # square subplots, capped so that all columns fit the viewport width
    subplot_width = min(
        GP_POSTERIOR_MAX_HEIGHT, viewport_size['width'] // num_cols)
    fig3.update_layout(
        title="Posterior Mean of Gaussian Process",
        font=dict(size=PLOT_AXES_FONT_SIZE),
        height=subplot_width * num_rows,
        width=subplot_width * num_cols,
        autosize=False,
        showlegend=False,
        template="plotly_white"
    )
    fig4.update_layout(
        title="Posterior Uncertainty of Gaussian Process",
        font=dict(size=PLOT_AXES_FONT_SIZE),
        height=subplot_width * num_rows,
        width=subplot_width * num_cols,
        autosize=False,
        showlegend=False,
        template="plotly_white"
    )

    # mark the stored tuning data as consumed
    self.tuning_gp_update = False
    return (fig1, fig2, fig3, fig4)
|
|
1259
|
+
|
|
1260
|
+
self.app = app
|
|
1261
|
+
|
|
1262
|
+
# ==========================================================================
|
|
1263
|
+
# DASHBOARD EXECUTION
|
|
1264
|
+
# ==========================================================================
|
|
1265
|
+
|
|
1266
|
+
def launch(self, port: int=1222, daemon: bool=True) -> None:
    '''Opens the dashboard address in a browser and serves the app on the
    given port from a background thread.

    The browser is only opened when the WERKZEUG_RUN_MAIN environment
    variable is unset or empty. The daemon argument is passed on to the
    serving thread.'''

    def run_dash():
        self.app.run(port=port)

    # open the browser to the required port
    already_running = os.environ.get("WERKZEUG_RUN_MAIN")
    if not already_running:
        webbrowser.open_new(f'http://127.0.0.1:{port}/')

    # run the app in a new thread at the specified port
    server_thread = threading.Thread(target=run_dash, daemon=daemon)
    server_thread.start()
|
|
1280
|
+
|
|
1281
|
+
@staticmethod
|
|
1282
|
+
def get_planner_info(planner: 'JaxBackpropPlanner') -> Dict[str, Any]:
|
|
1283
|
+
'''Some additional info directly from the planner that is required by
|
|
1284
|
+
the dashboard.'''
|
|
1285
|
+
return {
|
|
1286
|
+
'rddl': planner.rddl,
|
|
1287
|
+
'string': planner.summarize_system() + str(planner),
|
|
1288
|
+
'model_parameter_info': planner.compiled.model_parameter_info(),
|
|
1289
|
+
'trace_info': planner.compiled.traced
|
|
1290
|
+
}
|
|
1291
|
+
|
|
1292
|
+
def register_experiment(self, experiment_id: str,
                        planner_info: Dict[str, Any],
                        key: Optional[int]=None,
                        viz: Optional[Any]=None) -> str:
    '''Starts monitoring a new experiment and returns its id.

    experiment_id is the id to register; when None, an unused integer id is
    generated. planner_info is a dictionary with the entries 'rddl',
    'string', 'model_parameter_info' and 'trace_info' (as produced by
    get_planner_info). key is stored as the seed of the experiment, and viz
    is the optional visualizer stored for the policy visualization.

    Raises ValueError if an explicitly given experiment_id already exists.'''

    # make sure experiment id does not exist
    if experiment_id is None:
        # len + 1 alone can collide with an id that was registered
        # explicitly, so advance until an unused integer is found
        experiment_id = len(self.xticks) + 1
        while experiment_id in self.xticks:
            experiment_id += 1
    elif experiment_id in self.xticks:
        raise ValueError(f'An experiment with id {experiment_id} '
                         'was already created.')

    # experiment table info
    self.timestamps[experiment_id] = datetime.fromtimestamp(
        time.time()).strftime('%Y-%m-%d %H:%M:%S')
    self.duration[experiment_id] = 0
    self.seeds[experiment_id] = key
    self.status[experiment_id] = 'N/A'
    self.progress[experiment_id] = 0
    # NOTE(review): warnings is shared by all experiments and is reset here
    self.warnings = []
    self.rddl[experiment_id] = planner_info['rddl']
    self.planner_info[experiment_id] = planner_info['string']
    self.checked[experiment_id] = False

    # empty containers filled in by update_experiment
    self.xticks[experiment_id] = []
    self.train_return[experiment_id] = []
    self.test_return[experiment_id] = []
    self.return_dist_ticks[experiment_id] = []
    self.return_dist_last_progress[experiment_id] = 0
    self.return_dist[experiment_id] = []
    self.action_output[experiment_id] = None
    self.policy_params[experiment_id] = []
    self.policy_params_ticks[experiment_id] = []
    self.policy_params_last_progress[experiment_id] = 0
    self.policy_viz[experiment_id] = viz

    # decompile each expression that has a model parameter, keyed by its id
    decompiler = RDDLDecompiler()
    self.relaxed_exprs[experiment_id] = {}
    self.relaxed_exprs_values[experiment_id] = {}
    for info in planner_info['model_parameter_info'].values():
        expr = planner_info['trace_info'].lookup(info['id'])
        compiled_expr = decompiler.decompile_expr(expr)
        self.relaxed_exprs[experiment_id][info['id']] = compiled_expr
        self.relaxed_exprs_values[experiment_id][info['id']] = []

    return experiment_id
|
|
1338
|
+
|
|
1339
|
+
@staticmethod
|
|
1340
|
+
def representative_trajectories(trajectories, k, max_iter=300):
|
|
1341
|
+
n = next(iter(trajectories.values())).shape[0]
|
|
1342
|
+
points = np.concatenate([
|
|
1343
|
+
np.reshape(1. * values, (n, -1))
|
|
1344
|
+
for values in trajectories.values()
|
|
1345
|
+
], axis=1)
|
|
1346
|
+
|
|
1347
|
+
k = min(k, n)
|
|
1348
|
+
centroids = points[np.random.choice(n, k, replace=False)]
|
|
1349
|
+
for _ in range(max_iter):
|
|
1350
|
+
distances = np.linalg.norm(
|
|
1351
|
+
points[:, None, :] - centroids[None, :, :], axis=-1)
|
|
1352
|
+
cluster_assignment = np.argmin(distances, axis=1)
|
|
1353
|
+
new_centroids = np.stack([
|
|
1354
|
+
np.mean(points[cluster_assignment == i], axis=0)
|
|
1355
|
+
for i in range(k)
|
|
1356
|
+
], axis=0)
|
|
1357
|
+
if np.allclose(new_centroids, centroids):
|
|
1358
|
+
break
|
|
1359
|
+
centroids = new_centroids
|
|
1360
|
+
return np.unique(np.argmin(distances, axis=0))
|
|
1361
|
+
|
|
1362
|
+
def update_experiment(self, experiment_id: str, callback: Dict[str, Any]) -> None:
    '''Pass new information and update the dashboard for a given experiment.

    experiment_id is the id returned by register_experiment. callback is a
    dictionary that must provide the entries read below: 'iteration',
    'train_return', 'best_return', 'progress', 'reward', 'fluents',
    'best_params', 'model_params', 'train_log' (with 'reward' and
    'fluents'), 'status' and 'elapsed_time'.'''

    # data for return curves
    iteration = callback['iteration']
    self.xticks[experiment_id].append(iteration)
    self.train_return[experiment_id].append(callback['train_return'])
    self.test_return[experiment_id].append(callback['best_return'])

    # data for return distributions
    # a new snapshot is only stored after enough progress since the last one
    progress = callback['progress']
    if progress - self.return_dist_last_progress[experiment_id] \
            >= PROGRESS_FOR_NEXT_RETURN_DIST:
        self.return_dist_ticks[experiment_id].append(iteration)
        self.return_dist[experiment_id].append(
            np.sum(np.asarray(callback['reward']), axis=1))
        self.return_dist_last_progress[experiment_id] = progress

    # data for action heatmaps
    # each entry: (values with trailing axes flattened, name, groundings)
    action_output = []
    rddl = self.rddl[experiment_id]
    for action in rddl.action_fluents:
        action_values = np.asarray(callback['fluents'][action])
        action_output.append(
            (action_values.reshape(action_values.shape[:2] + (-1,)),
             action,
             rddl.variable_groundings[action])
        )
    self.action_output[experiment_id] = action_output

    # data for policy weight distributions
    if progress - self.policy_params_last_progress[experiment_id] \
            >= PROGRESS_FOR_NEXT_POLICY_DIST:
        self.policy_params_ticks[experiment_id].append(iteration)
        self.policy_params[experiment_id].append(callback['best_params'])
        self.policy_params_last_progress[experiment_id] = progress

    # data for model relaxations
    # the expression id is the part of the key before the first underscore
    model_params = callback['model_params']
    for (key, values) in model_params.items():
        expr_id = int(str(key).split('_')[0])
        self.relaxed_exprs_values[experiment_id][expr_id].append(values.item())
    self.train_reward_dist[experiment_id] = callback['train_log']['reward']
    self.test_reward_dist[experiment_id] = callback['reward']

    # state fluents are always recorded; observation fluents are added when
    # both logs contain them (the previous "a or name in b" iterable never
    # included them and failed on an undefined name when there were no
    # state fluents)
    train_fluents = callback['train_log']['fluents']
    test_fluents = callback['fluents']
    fluent_names = list(rddl.state_fluents)
    fluent_names.extend(
        name for name in rddl.observ_fluents
        if name in train_fluents and name in test_fluents
        and name not in fluent_names)
    self.train_state_fluents[experiment_id] = {
        name: np.asarray(train_fluents[name])
        for name in fluent_names
    }
    self.test_state_fluents[experiment_id] = {
        name: np.asarray(test_fluents[name])
        for name in self.train_state_fluents[experiment_id]
    }

    # update experiment table info
    # keeps the part after the first '.' of the status text
    self.status[experiment_id] = str(callback['status']).split('.')[1]
    self.duration[experiment_id] = callback["elapsed_time"]
    self.progress[experiment_id] = progress
    # NOTE(review): register_experiment sets warnings to [] but this sets it
    # to None - confirm which one readers of this attribute expect
    self.warnings = None
|
|
1420
|
+
|
|
1421
|
+
def update_tuning(self, optimizer: Any,
                  bounds: Dict[str, Tuple[float, float]]) -> None:
    '''Updates the hyper-parameter tuning plots.

    optimizer must provide res (list of trials with 'target' and 'params'),
    space (with target, params and keys) and _gp (with a predict method);
    presumably a bayes_opt optimizer - confirm against the caller. bounds
    maps each parameter name to its (lower, upper) range.

    Stores the trial targets, the surrogate predictions at the trial points
    and the trial parameter values, then for every pair of parameters
    evaluates the surrogate mean and standard deviation on a 100 x 100 grid
    while all other parameters are fixed to their values in the best trial.
    Does nothing except clearing the stored heatmaps when there are no
    trials yet.'''
    self.tuning_gp_update = False
    if not optimizer.res:
        self.tuning_gp_heatmaps = []
        return

    keys = list(optimizer.space.keys)
    self.tuning_gp_targets = optimizer.space.target.reshape((-1,))
    self.tuning_gp_predicted = \
        optimizer._gp.predict(optimizer.space.params).reshape((-1,))
    self.tuning_gp_params = {name: optimizer.space.params[:, i]
                             for (i, name) in enumerate(keys)}

    # the best trial does not depend on the parameter pair: look it up once
    best_params = max(optimizer.res, key=lambda x: x['target'])['params']
    best_point = np.asarray([best_params[name] for name in keys], dtype=float)

    # build the heatmaps locally and publish them in one assignment, since
    # the dashboard callbacks read them from another thread
    heatmaps = []
    for (i1, param1) in enumerate(keys):
        column = []
        for (i2, param2) in enumerate(keys):
            if i2 <= i1:
                continue

            # Generate a grid for visualization
            p1_values = np.linspace(*bounds[param1], 100)
            p2_values = np.linspace(*bounds[param2], 100)
            P1, P2 = np.meshgrid(p1_values, p2_values)

            # every grid point is the best trial with the two selected
            # parameters replaced by the grid coordinates
            param_grid = np.tile(best_point, (P1.size, 1))
            param_grid[:, i1] = np.ravel(P1)
            param_grid[:, i2] = np.ravel(P2)

            # Predict the mean and deviation of the surrogate model
            mean, std = optimizer._gp.predict(param_grid, return_std=True)
            mean = mean.reshape(P1.shape)
            std = std.reshape(P1.shape)
            column.append(
                (param1, param2, p1_values, p2_values, mean, std))
        heatmaps.append(column)

    self.tuning_gp_heatmaps = heatmaps
    self.tuning_gp_update = True
|