pyRDDLGym-jax 2.7-py3-none-any.whl → 3.0-py3-none-any.whl
This diff is provided for informational purposes only and reflects the changes between the two publicly released package versions as they appear in the public registry.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +1080 -906
- pyRDDLGym_jax/core/logic.py +1537 -1369
- pyRDDLGym_jax/core/model.py +75 -86
- pyRDDLGym_jax/core/planner.py +883 -935
- pyRDDLGym_jax/core/simulator.py +20 -17
- pyRDDLGym_jax/core/tuning.py +11 -7
- pyRDDLGym_jax/core/visualization.py +115 -78
- pyRDDLGym_jax/entry_point.py +2 -1
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +6 -8
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +7 -8
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +7 -8
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +8 -9
- pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +7 -8
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +6 -7
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +6 -7
- pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +6 -8
- pyRDDLGym_jax/examples/configs/Quadcopter_physics_drp.cfg +17 -0
- pyRDDLGym_jax/examples/configs/Quadcopter_physics_slp.cfg +17 -0
- pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +4 -7
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +4 -7
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +6 -7
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +6 -7
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +6 -7
- pyRDDLGym_jax/examples/configs/default_drp.cfg +5 -8
- pyRDDLGym_jax/examples/configs/default_replan.cfg +5 -8
- pyRDDLGym_jax/examples/configs/default_slp.cfg +5 -8
- pyRDDLGym_jax/examples/configs/tuning_drp.cfg +6 -8
- pyRDDLGym_jax/examples/configs/tuning_replan.cfg +6 -8
- pyRDDLGym_jax/examples/configs/tuning_slp.cfg +6 -8
- pyRDDLGym_jax/examples/run_plan.py +2 -33
- pyRDDLGym_jax/examples/run_tune.py +2 -2
- {pyrddlgym_jax-2.7.dist-info → pyrddlgym_jax-3.0.dist-info}/METADATA +22 -23
- pyrddlgym_jax-3.0.dist-info/RECORD +51 -0
- {pyrddlgym_jax-2.7.dist-info → pyrddlgym_jax-3.0.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax/examples/run_gradient.py +0 -102
- pyrddlgym_jax-2.7.dist-info/RECORD +0 -50
- {pyrddlgym_jax-2.7.dist-info → pyrddlgym_jax-3.0.dist-info}/entry_points.txt +0 -0
- {pyrddlgym_jax-2.7.dist-info → pyrddlgym_jax-3.0.dist-info}/licenses/LICENSE +0 -0
- {pyrddlgym_jax-2.7.dist-info → pyrddlgym_jax-3.0.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/simulator.py
CHANGED
@@ -20,7 +20,7 @@
 
 import time
 import numpy as np
-from typing import Callable, Dict, Optional, Union
+from typing import Callable, Dict, Optional, Tuple, Union
 
 import jax
 
@@ -103,7 +103,7 @@ class JaxRDDLSimulator(RDDLSimulator):
         self.terminals = jax.tree_util.tree_map(jax.jit, compiled.terminations)
         self.reward = jax.jit(compiled.reward)
         jax_cpfs = jax.tree_util.tree_map(jax.jit, compiled.cpfs)
-        self.model_params = compiled.
+        self.model_params = compiled.model_aux['params']
 
         # level analysis
         self.cpfs = []
@@ -116,6 +116,7 @@ class JaxRDDLSimulator(RDDLSimulator):
 
         # initialize all fluent and non-fluent values
         self.subs = self.init_values.copy()
+        self.fls, self.nfls = compiled.split_fluent_nonfluent(self.subs)
         self.state = None
         self.noop_actions = {var: values
                              for (var, values) in self.init_values.items()
@@ -142,24 +143,23 @@ class JaxRDDLSimulator(RDDLSimulator):
         for (i, invariant) in enumerate(self.invariants):
             loc = self.invariant_names[i]
             sample, self.key, error, self.model_params = invariant(
-                self.
+                self.fls, self.nfls, self.model_params, self.key)
             self.handle_error_code(error, loc)
             if not bool(sample):
                 if not silent:
-                    raise RDDLStateInvariantNotSatisfiedError(
-                        f'{loc} is not satisfied.')
+                    raise RDDLStateInvariantNotSatisfiedError(f'{loc} is not satisfied.')
                 return False
         return True
 
     def check_action_preconditions(self, actions: Args, silent: bool=False) -> bool:
         '''Throws an exception if the action preconditions are not satisfied.'''
-
-        subs.update(actions)
+        self.fls.update(actions)
+        self.subs.update(actions)
 
         for (i, precond) in enumerate(self.preconds):
             loc = self.precond_names[i]
             sample, self.key, error, self.model_params = precond(
-
+                self.fls, self.nfls, self.model_params, self.key)
             self.handle_error_code(error, loc)
             if not bool(sample):
                 if not silent:
@@ -173,7 +173,7 @@ class JaxRDDLSimulator(RDDLSimulator):
         for (i, terminal) in enumerate(self.terminals):
             loc = self.terminal_names[i]
             sample, self.key, error, self.model_params = terminal(
-                self.
+                self.fls, self.nfls, self.model_params, self.key)
             self.handle_error_code(error, loc)
             if bool(sample):
                 return True
@@ -182,24 +182,26 @@ class JaxRDDLSimulator(RDDLSimulator):
     def sample_reward(self) -> float:
         '''Samples the current reward given the current state and action.'''
         reward, self.key, error, self.model_params = self.reward(
-            self.
+            self.fls, self.nfls, self.model_params, self.key)
         self.handle_error_code(error, 'reward function')
         return float(reward)
 
-    def step(self, actions: Args) -> Args:
+    def step(self, actions: Args) -> Tuple[Args, float, bool]:
         '''Samples and returns the next state from the cpfs.
 
         :param actions: a dict mapping current action fluents to their values
         '''
         rddl = self.rddl
         keep_tensors = self.keep_tensors
-        subs = self.subs
+        subs, fls, nfls = self.subs, self.fls, self.nfls
         subs.update(actions)
+        fls.update(actions)
 
         # compute CPFs in topological order
         for (cpf, expr, _) in self.cpfs:
-
-
+            fls[cpf], self.key, error, self.model_params = expr(
+                fls, nfls, self.model_params, self.key)
+            subs[cpf] = fls[cpf]
             self.handle_error_code(error, f'CPF <{cpf}>')
 
         # sample reward
@@ -210,10 +212,11 @@ class JaxRDDLSimulator(RDDLSimulator):
         for (state, next_state) in rddl.next_state.items():
 
             # set state = state' for the next epoch
+            fls[state] = fls[next_state]
             subs[state] = subs[next_state]
 
             # convert object integer to string representation
-            state_values =
+            state_values = fls[state]
             if self.objects_as_strings:
                 ptype = rddl.variable_ranges[state]
                 if ptype not in RDDLValueInitializer.NUMPY_TYPES:
@@ -231,7 +234,7 @@ class JaxRDDLSimulator(RDDLSimulator):
         for var in rddl.observ_fluents:
 
             # convert object integer to string representation
-            obs_values =
+            obs_values = fls[var]
             if self.objects_as_strings:
                 ptype = rddl.variable_ranges[var]
                 if ptype not in RDDLValueInitializer.NUMPY_TYPES:
@@ -244,7 +247,7 @@ class JaxRDDLSimulator(RDDLSimulator):
             obs.update(rddl.ground_var_with_values(var, obs_values))
         else:
             obs = self.state
-
+
         done = self.check_terminal_states()
         return obs, reward, done
 
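In 3.0 the simulator threads two substitution dictionaries through every jitted expression: self.fls for fluents and self.nfls for non-fluents, produced by compiled.split_fluent_nonfluent(self.subs); step() also now advertises its (observations, reward, done) return explicitly. A standalone sketch of the splitting idea follows; the helper and the fluent names below are illustrative only, not the library's implementation.

from typing import Any, Dict, Set, Tuple

def split_fluent_nonfluent(subs: Dict[str, Any],
                           fluent_names: Set[str]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    # Partition a substitution dictionary into fluent and non-fluent parts,
    # mirroring the fls / nfls pair the 3.0 simulator keeps separate.
    fls = {name: value for (name, value) in subs.items() if name in fluent_names}
    nfls = {name: value for (name, value) in subs.items() if name not in fluent_names}
    return fls, nfls

subs = {'pos': 0.0, 'vel': 0.0, 'GRAVITY': 9.8}
fls, nfls = split_fluent_nonfluent(subs, fluent_names={'pos', 'vel'})
print(fls)   # {'pos': 0.0, 'vel': 0.0}
print(nfls)  # {'GRAVITY': 9.8}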
pyRDDLGym_jax/core/tuning.py
CHANGED
@@ -248,8 +248,8 @@ class JaxParameterTuning:
             policy = JaxOfflineController(
                 planner=planner, key=subkey, tqdm_position=index,
                 params=best_params, train_on_reset=False)
-            total_reward = policy.evaluate(
-
+            total_reward = policy.evaluate(
+                env, episodes=rollouts_per_trial, seed=np.array(subkey)[0])['mean']
 
             # update average reward
             if verbose:
@@ -321,7 +321,8 @@ class JaxParameterTuning:
                            index: int,
                            iteration: int,
                            kwargs: Kwargs,
-                           queue: object
+                           queue: object,
+                           show_dashboard: bool) -> Tuple[ParameterValues, float, int, int]:
        '''A pickleable objective function to evaluate a single hyper-parameter
        configuration.'''
 
@@ -345,7 +346,10 @@ class JaxParameterTuning:
        planner_args, _, train_args = load_config_from_string(config_string)
 
        # remove keywords that should not be in the tuner
-
+        if show_dashboard:
+            planner_args['dashboard'] = True
+        else:
+            planner_args['dashboard'] = None
        planner_args.pop('parallel_updates', None)
 
        # initialize env for evaluation (need fresh copy to avoid concurrency)
@@ -353,6 +357,7 @@ class JaxParameterTuning:
 
        # run planning algorithm
        planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
+        planner.dashboard = None
        if online:
            average_reward = JaxParameterTuning.online_trials(
                env, planner, train_args, key, iteration, index, num_trials,
@@ -482,7 +487,7 @@ class JaxParameterTuning:
            # assign jobs to worker pool
            results = [
                pool.apply_async(JaxParameterTuning.objective_function,
-                                obj_args + (it, obj_kwargs, queue))
+                                obj_args + (it, obj_kwargs, queue, show_dashboard))
                for obj_args in zip(suggested_params, subkeys, worker_ids)
            ]
 
@@ -502,8 +507,7 @@ class JaxParameterTuning:
                # extract and register the new evaluation
                params, target, index, pid = results.pop(i).get()
                optimizer.register(params, target)
-                optimizer._gp.fit(
-                    optimizer.space.params, optimizer.space.target)
+                optimizer._gp.fit(optimizer.space.params, optimizer.space.target)
 
                # update acquisition function and suggest a new point
                suggested_params[index] = optimizer.suggest()
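The tuning changes above pass a new show_dashboard flag from the parallel driver into the pickleable objective_function by appending it to the tuple handed to pool.apply_async. A minimal, self-contained sketch of that pattern, using placeholder names rather than the library's real objective:

import multiprocessing as mp

def objective_function(params, key, index, iteration, kwargs, queue, show_dashboard):
    # Must be a module-level (pickleable) function; extra settings such as
    # show_dashboard travel as ordinary positional arguments.
    target = -abs(params - 0.5)
    return params, target, index, show_dashboard

if __name__ == '__main__':
    with mp.Pool(processes=2) as pool:
        results = [
            pool.apply_async(objective_function, (p, None, i, 0, {}, None, False))
            for (i, p) in enumerate([0.1, 0.9])
        ]
        print([r.get() for r in results])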
pyRDDLGym_jax/core/visualization.py
CHANGED

@@ -18,6 +18,8 @@ import os
 from datetime import datetime
 import math
 import numpy as np
+import io
+import pickle
 import time
 import threading
 from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
@@ -29,7 +31,7 @@ log = logging.getLogger('werkzeug')
 log.setLevel(logging.ERROR)
 
 import dash
-from dash.dcc import Interval, Graph, Store
+from dash.dcc import Download, Interval, Graph, Store
 from dash.dependencies import Input, Output, State, ALL
 from dash.html import Div, B, H4, P, Hr
 import dash_bootstrap_components as dbc
@@ -48,6 +50,7 @@ POLICY_DIST_PLOTS_PER_ROW = 6
 ACTION_HEATMAP_HEIGHT = 400
 PROGRESS_FOR_NEXT_RETURN_DIST = 2
 PROGRESS_FOR_NEXT_POLICY_DIST = 10
+PROGRESS_FOR_NEXT_BASIC_TIME_CURVE = 0.05
 REWARD_ERROR_DIST_SUBPLOTS = 20
 MODEL_STATE_ERROR_HEIGHT = 300
 POLICY_STATE_VIZ_MAX_HEIGHT = 800
@@ -77,6 +80,7 @@ class JaxPlannerDashboard:
         self.test_return = {}
         self.train_return = {}
         self.pgpe_return = {}
+        self.basic_time_curve_last_progress = {}
         self.return_dist = {}
         self.return_dist_ticks = {}
         self.return_dist_last_progress = {}
@@ -91,7 +95,8 @@ class JaxPlannerDashboard:
         self.train_reward_dist = {}
         self.test_reward_dist = {}
         self.train_state_fluents = {}
-        self.
+        self.train_state_output = {}
+        self.test_state_output = {}
 
         self.tuning_gp_heatmaps = None
         self.tuning_gp_targets = None
@@ -269,6 +274,8 @@ class JaxPlannerDashboard:
                 dbc.DropdownMenuItem("30s", id='30sec'),
                 dbc.DropdownMenuItem("1m", id='1min'),
                 dbc.DropdownMenuItem("5m", id='5min'),
+                dbc.DropdownMenuItem("30m", id='30min'),
+                dbc.DropdownMenuItem("1h", id='1h'),
                 dbc.DropdownMenuItem("1d", id='1day')],
                label="Refresh: 2s",
                id='refresh-rate-dropdown',
@@ -281,7 +288,7 @@ class JaxPlannerDashboard:
                 dbc.DropdownMenuItem("10", id='10pp'),
                 dbc.DropdownMenuItem("25", id='25pp'),
                 dbc.DropdownMenuItem("50", id='50pp')],
-                label="
+                label="Results Per Page: 10",
                id='experiment-num-per-page-dropdown',
                nav=True
            )
@@ -328,6 +335,13 @@ class JaxPlannerDashboard:
            # policy
            dbc.Tab(dbc.Card(
                dbc.CardBody([
+                    dbc.Row([
+                        dbc.Col([
+                            dbc.Button('Save Policy Weights',
+                                       id='policy-save-button'),
+                            Download(id="download-policy")
+                        ], width='auto')
+                    ]),
                    dbc.Row([
                        Graph(id='action-output'),
                    ]),
@@ -506,10 +520,12 @@ class JaxPlannerDashboard:
             Input("30sec", "n_clicks"),
             Input("1min", "n_clicks"),
             Input("5min", "n_clicks"),
+            Input("30min", "n_clicks"),
+            Input("1h", "n_clicks"),
             Input("1day", "n_clicks")],
            [State('refresh-interval', 'data')]
        )
-        def click_refresh_rate(n05, n1, n2, n5, n10, n30, n1m, n5m, nd, data):
+        def click_refresh_rate(n05, n1, n2, n5, n10, n30, n1m, n5m, n30m, n1h, nd, data):
            ctx = dash.callback_context
            if not ctx.triggered:
                return data
@@ -530,6 +546,10 @@ class JaxPlannerDashboard:
                return 60000
            elif button_id == '5min':
                return 300000
+            elif button_id == '30min':
+                return 1800000
+            elif button_id == '1h':
+                return 3600000
            elif button_id == '1day':
                return 86400000
            return data
@@ -562,8 +582,14 @@ class JaxPlannerDashboard:
                return 'Refresh: 1m'
            elif selected_interval == 300000:
                return 'Refresh: 5m'
+            elif selected_interval == 1800000:
+                return 'Refresh: 30m'
+            elif selected_interval == 3600000:
+                return 'Refresh: 1h'
+            elif selected_interval == 86400000:
+                return 'Refresh: 1day'
            else:
-                return 'Refresh:
+                return 'Refresh: n/a'
 
        # update the experiments per page
        @app.callback(
@@ -594,7 +620,7 @@ class JaxPlannerDashboard:
            [Input('experiment-num-per-page', 'data')]
        )
        def update_experiments_per_page(selected_num):
-            return f'
+            return f'Results Per Page: {selected_num}'
 
        # update the experiment table
        @app.callback(
@@ -758,7 +784,7 @@ class JaxPlannerDashboard:
            if checked and self.action_output[row] is not None:
                num_plots = len(self.action_output[row])
                titles = []
-                for
+                for act in self.action_output[row].keys():
                    titles.append(f'Values of Action-Fluents {act}')
                    titles.append(f'Std. Dev. of Action-Fluents {act}')
                fig = make_subplots(
@@ -766,8 +792,7 @@ class JaxPlannerDashboard:
                    shared_xaxes=True, horizontal_spacing=0.15,
                    subplot_titles=titles
                )
-                for (i,
-                        in enumerate(self.action_output[row]):
+                for (i, action_output) in enumerate(self.action_output[row].values()):
                    action_values = np.mean(1. * action_output, axis=0).T
                    action_errors = np.std(1. * action_output, axis=0).T
                    fig.add_trace(go.Heatmap(
@@ -984,6 +1009,21 @@ class JaxPlannerDashboard:
                return fig
            return dash.no_update
 
+        # save policy button
+        @app.callback(
+            Output('download-policy', 'data'),
+            Input("policy-save-button", "n_clicks"),
+            prevent_initial_call=True
+        )
+        def save_policy_weights(n_clicks):
+            for (row, checked) in self.checked.copy().items():
+                if checked:
+                    bytes_io = io.BytesIO()
+                    pickle.dump(self.policy_params[row], bytes_io)
+                    bytes_io.seek(0)
+                    return dash.dcc.send_bytes(bytes_io.read(), "policy_params.pkl")
+            return dash.no_update
+
        # update the model parameter information
        @app.callback(
            Output('model-params-dropdown', 'children'),
@@ -1136,42 +1176,40 @@ class JaxPlannerDashboard:
            if not state: return fig
            for (row, checked) in self.checked.copy().items():
                if checked and row in self.train_state_fluents:
-                    train_values = 1 * train_values.reshape(train_values.shape[:2] + (-1,))
-                    test_values = 1 * test_values.reshape(test_values.shape[:2] + (-1,))
-                    num_epochs, num_states = train_values.shape[1:]
-                    step = 1
-                    if num_epochs > REWARD_ERROR_DIST_SUBPLOTS:
-                        step = num_epochs // REWARD_ERROR_DIST_SUBPLOTS
+                    titles = [f'Values of Train State-Fluents {state}',
+                              f'Values of Test State-Fluents {state}']
                    fig = make_subplots(
-                        rows=
-
+                        rows=1, cols=2,
+                        shared_xaxes=True, horizontal_spacing=0.15,
+                        subplot_titles=titles
                    )
-                    fig.
+                    train_state_output = self.train_state_output[row][state]
+                    test_state_output = self.test_state_output[row][state]
+                    train_state_values = np.mean(1. * train_state_output, axis=0).T
+                    test_state_values = np.mean(1. * test_state_output, axis=0).T
+                    fig.add_trace(go.Heatmap(
+                        z=train_state_values,
+                        x=np.arange(train_state_values.shape[1]),
+                        y=np.arange(train_state_values.shape[0]),
+                        colorscale='Blues', colorbar_x=0.45,
+                        colorbar_len=0.8 / 1,
+                        colorbar_y=1 - (0.5) / 1
+                    ), row=1, col=1)
+                    fig.add_trace(go.Heatmap(
+                        z=test_state_values,
+                        x=np.arange(test_state_values.shape[1]),
+                        y=np.arange(test_state_values.shape[0]),
+                        colorscale='Blues', colorbar_len=0.8 / 1,
+                        colorbar_y=1 - (0.5) / 1
+                    ), row=1, col=2)
                    fig.update_layout(
-                        title=
-                            f"in Relaxed Model vs True Model")),
+                        title=f"Values of State-Fluents {state}",
                        xaxis=dict(title=dict(text="Decision Epoch")),
-                        yaxis=dict(title=dict(text="State-Fluent Value")),
                        font=dict(size=PLOT_AXES_FONT_SIZE),
-                        height=
-
-                        legend=dict(bgcolor='rgba(0,0,0,0)'),
+                        height=ACTION_HEATMAP_HEIGHT * 1,
+                        showlegend=False,
                        template="plotly_white"
-                    )
+                    )
                    break
            return fig
 
@@ -1363,6 +1401,7 @@ class JaxPlannerDashboard:
        self.train_return[experiment_id] = []
        self.test_return[experiment_id] = []
        self.pgpe_return[experiment_id] = []
+        self.basic_time_curve_last_progress[experiment_id] = 0
        self.return_dist_ticks[experiment_id] = []
        self.return_dist_last_progress[experiment_id] = 0
        self.return_dist[experiment_id] = []
@@ -1410,64 +1449,63 @@ class JaxPlannerDashboard:
        '''Pass new information and update the dashboard for a given experiment.'''
 
        # data for return curves
+        progress = callback['progress']
        iteration = callback['iteration']
-        self.
+        if progress - self.basic_time_curve_last_progress[experiment_id] >= PROGRESS_FOR_NEXT_BASIC_TIME_CURVE:
+            self.xticks[experiment_id].append(iteration)
+            self.train_return[experiment_id].append(np.min(callback['train_return']))
+            self.test_return[experiment_id].append(np.min(callback['best_return']))
+            self.pgpe_return[experiment_id].append(np.min(callback['pgpe_return']))
+            for (key, values) in callback['model_params'].items():
+                self.relaxed_exprs_values[experiment_id][key].append(values[0])
+            self.basic_time_curve_last_progress[experiment_id] = progress
 
        # data for return distributions
-        progress
-        if progress - self.return_dist_last_progress[experiment_id] \
-            >= PROGRESS_FOR_NEXT_RETURN_DIST:
+        if progress - self.return_dist_last_progress[experiment_id] >= PROGRESS_FOR_NEXT_RETURN_DIST:
            self.return_dist_ticks[experiment_id].append(iteration)
            self.return_dist[experiment_id].append(
-                np.sum(np.
+                np.sum(np.mean(callback['test_log']['reward'], axis=0), axis=1))
            self.return_dist_last_progress[experiment_id] = progress
 
+        # data for policy weight distributions
+        if progress - self.policy_params_last_progress[experiment_id] >= PROGRESS_FOR_NEXT_POLICY_DIST:
+            self.policy_params_ticks[experiment_id].append(iteration)
+            self.policy_params[experiment_id].append(callback['best_params'])
+            self.policy_params_last_progress[experiment_id] = progress
+
        # data for action heatmaps
-        action_output =
+        action_output = {}
        rddl = self.rddl[experiment_id]
        for action in rddl.action_fluents:
-            action_values = np.asarray(callback['fluents'][action])
-            action_output.
-                (action_values.reshape(action_values.shape[:2] + (-1,)),
-                 action,
-                 rddl.variable_groundings[action])
-            )
+            action_values = np.asarray(callback['test_log']['fluents'][action][0])
+            action_output[action] = action_values.reshape(action_values.shape[:2] + (-1,))
        self.action_output[experiment_id] = action_output
 
-        # data for
+        # data for state heatmaps
+        train_state_output = {}
+        test_state_output = {}
+        for state in rddl.state_fluents:
+            state_values = np.asarray(callback['train_log']['fluents'][state][0])
+            train_state_output[state] = state_values.reshape(state_values.shape[:2] + (-1,))
+            state_values = np.asarray(callback['test_log']['fluents'][state][0])
+            test_state_output[state] = state_values.reshape(state_values.shape[:2] + (-1,))
+        self.train_state_output[experiment_id] = train_state_output
+        self.test_state_output[experiment_id] = test_state_output
 
-        # data for
-            expr_id = int(str(key).split('_')[0])
-            self.relaxed_exprs_values[experiment_id][expr_id].append(values.item())
-        self.train_reward_dist[experiment_id] = callback['train_log']['reward']
-        self.test_reward_dist[experiment_id] = callback['reward']
+        # data for reward distributions
+        self.train_reward_dist[experiment_id] = np.mean(callback['train_log']['reward'], axis=0)
+        self.test_reward_dist[experiment_id] = np.mean(callback['test_log']['reward'], axis=0)
        self.train_state_fluents[experiment_id] = {
-            name: np.asarray(callback['train_log']['fluents'][name])
+            name: np.asarray(callback['train_log']['fluents'][name][0])
            for name in rddl.state_fluents
        }
-        self.test_state_fluents[experiment_id] = {
-            name: np.asarray(callback['fluents'][name])
-            for name in self.train_state_fluents[experiment_id]
-        }
-
        # update experiment table info
        self.status[experiment_id] = str(callback['status']).split('.')[1]
        self.duration[experiment_id] = callback["elapsed_time"]
-        self.progress[experiment_id] = progress
+        self.progress[experiment_id] = int(progress)
        self.warnings = None
 
-    def update_tuning(self, optimizer: Any,
-                      bounds: Dict[str, Tuple[float, float]]) -> None:
+    def update_tuning(self, optimizer: Any, bounds: Dict[str, Tuple[float, float]]) -> None:
        '''Updates the hyper-parameter tuning plots.'''
 
        self.tuning_gp_heatmaps = []
@@ -1475,8 +1513,7 @@ class JaxPlannerDashboard:
        if not optimizer.res: return
 
        self.tuning_gp_targets = optimizer.space.target.reshape((-1,))
-        self.tuning_gp_predicted =
-            optimizer._gp.predict(optimizer.space.params).reshape((-1,))
+        self.tuning_gp_predicted = optimizer._gp.predict(optimizer.space.params).reshape((-1,))
        self.tuning_gp_params = {name: optimizer.space.params[:, i]
                                 for (i, name) in enumerate(optimizer.space.keys)}
 
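The new 'Save Policy Weights' button pickles the checked experiment's stored policy parameters and serves them through Dash's Download component as policy_params.pkl. Reading the downloaded file back is plain pickle; a sketch, assuming the file sits in the current working directory:

import pickle

with open('policy_params.pkl', 'rb') as f:
    policy_params = pickle.load(f)

# policy_params holds whatever the dashboard stored in self.policy_params[row]
# for the checked experiment (see the save_policy_weights callback above).
print(type(policy_params))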
pyRDDLGym_jax/entry_point.py
CHANGED
@@ -1,6 +1,5 @@
 import argparse
 
-from pyRDDLGym_jax.examples import run_plan, run_tune
 
 EPILOG = 'For complete documentation, see https://pyrddlgym.readthedocs.io/en/latest/jax.html.'
 
@@ -47,8 +46,10 @@ def main():
     # dispatch
     args = parser.parse_args()
     if args.jaxplan == "plan":
+        from pyRDDLGym_jax.examples import run_plan
         run_plan.main(args.domain, args.instance, args.method, args.episodes)
     elif args.jaxplan == "tune":
+        from pyRDDLGym_jax.examples import run_tune
         run_tune.main(args.domain, args.instance, args.method,
                       args.trials, args.iters, args.workers, args.dashboard,
                       args.filepath)
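The entry-point change defers importing run_plan and run_tune until the matching subcommand is actually dispatched, presumably so that parsing --help or running one subcommand does not pay for the other's imports. The same pattern in a minimal standalone form (the branch-local modules here are stand-ins, not the package's):

import argparse

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('command', choices=['plan', 'tune'])
    args = parser.parse_args()
    if args.command == 'plan':
        import json as plan_backend   # placeholder for a heavy, plan-only import
        print('plan backend loaded:', plan_backend.__name__)
    elif args.command == 'tune':
        import csv as tune_backend    # placeholder for a heavy, tune-only import
        print('tune backend loaded:', tune_backend.__name__)

if __name__ == '__main__':
    main()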
pyRDDLGym_jax/examples/configs (.cfg files)
CHANGED

@@ -1,17 +1,15 @@
-[
-
-
-rounding_kwargs={'weight': 20}
-control_kwargs={'weight': 20}
+[Compiler]
+method='DefaultJaxRDDLCompilerWithGrad'
+sigmoid_weight=20
 
-[
+[Planner]
 method='JaxDeepReactivePolicy'
-method_kwargs={'topology': [32,
+method_kwargs={'topology': [32, 32]}
 optimizer='rmsprop'
 optimizer_kwargs={'learning_rate': 0.005}
 batch_size_train=1
 batch_size_test=1
 
-[
+[Optimize]
 key=42
 epochs=1000

@@ -1,10 +1,8 @@
-[
-
-
-rounding_kwargs={'weight': 20}
-control_kwargs={'weight': 20}
+[Compiler]
+method='DefaultJaxRDDLCompilerWithGrad'
+sigmoid_weight=20
 
-[
+[Planner]
 method='JaxStraightLinePlan'
 method_kwargs={}
 optimizer='rmsprop'
@@ -13,7 +11,7 @@ batch_size_train=1
 batch_size_test=1
 rollout_horizon=30
 
-[
+[Optimize]
 key=42
 train_seconds=0.5
 print_summary=False

@@ -1,10 +1,8 @@
-[
-
-
-rounding_kwargs={'weight': 20}
-control_kwargs={'weight': 20}
+[Compiler]
+method='DefaultJaxRDDLCompilerWithGrad'
+sigmoid_weight=20
 
-[
+[Planner]
 method='JaxStraightLinePlan'
 method_kwargs={}
 optimizer='rmsprop'
@@ -12,7 +10,8 @@ optimizer_kwargs={'learning_rate': 0.001}
 batch_size_train=1
 batch_size_test=1
 clip_grad=1.0
+pgpe_kwargs={'optimizer_kwargs_mu': {'learning_rate': 0.01}, 'optimizer_kwargs_sigma': {'learning_rate': 0.01}}
 
-[
+[Optimize]
 key=42
-epochs=
+epochs=3000

@@ -1,10 +1,9 @@
-[
-
-
-
-control_kwargs={'weight': 5}
+[Compiler]
+method='DefaultJaxRDDLCompilerWithGrad'
+bernoulli_sigmoid_weight=5
+sigmoid_weight=5
 
-[
+[Planner]
 method='JaxDeepReactivePolicy'
 method_kwargs={'topology': [64, 64]}
 optimizer='rmsprop'
@@ -12,7 +11,7 @@ optimizer_kwargs={'learning_rate': 0.001}
 batch_size_train=1
 batch_size_test=1
 
-[
+[Optimize]
 key=42
 epochs=6000
-train_seconds=
+train_seconds=90