pyRDDLGym-jax 2.8-py3-none-any.whl → 3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. pyRDDLGym_jax/__init__.py +1 -1
  2. pyRDDLGym_jax/core/compiler.py +1080 -906
  3. pyRDDLGym_jax/core/logic.py +1537 -1369
  4. pyRDDLGym_jax/core/model.py +75 -86
  5. pyRDDLGym_jax/core/planner.py +883 -935
  6. pyRDDLGym_jax/core/simulator.py +20 -17
  7. pyRDDLGym_jax/core/tuning.py +11 -7
  8. pyRDDLGym_jax/core/visualization.py +115 -78
  9. pyRDDLGym_jax/entry_point.py +2 -1
  10. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +6 -8
  11. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +5 -7
  12. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +7 -8
  13. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +7 -8
  14. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +8 -9
  15. pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +5 -7
  16. pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +5 -7
  17. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +7 -8
  18. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +6 -7
  19. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +6 -7
  20. pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +6 -8
  21. pyRDDLGym_jax/examples/configs/Quadcopter_physics_drp.cfg +17 -0
  22. pyRDDLGym_jax/examples/configs/Quadcopter_physics_slp.cfg +17 -0
  23. pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +5 -7
  24. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +4 -7
  25. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +5 -7
  26. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +4 -7
  27. pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +5 -7
  28. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +6 -7
  29. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +6 -7
  30. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +6 -7
  31. pyRDDLGym_jax/examples/configs/default_drp.cfg +5 -8
  32. pyRDDLGym_jax/examples/configs/default_replan.cfg +5 -8
  33. pyRDDLGym_jax/examples/configs/default_slp.cfg +5 -8
  34. pyRDDLGym_jax/examples/configs/tuning_drp.cfg +6 -8
  35. pyRDDLGym_jax/examples/configs/tuning_replan.cfg +6 -8
  36. pyRDDLGym_jax/examples/configs/tuning_slp.cfg +6 -8
  37. pyRDDLGym_jax/examples/run_plan.py +2 -2
  38. pyRDDLGym_jax/examples/run_tune.py +2 -2
  39. {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/METADATA +22 -23
  40. pyrddlgym_jax-3.0.dist-info/RECORD +51 -0
  41. {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/WHEEL +1 -1
  42. pyRDDLGym_jax/examples/run_gradient.py +0 -102
  43. pyrddlgym_jax-2.8.dist-info/RECORD +0 -50
  44. {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/entry_points.txt +0 -0
  45. {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/licenses/LICENSE +0 -0
  46. {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/simulator.py
@@ -20,7 +20,7 @@
 
  import time
  import numpy as np
- from typing import Callable, Dict, Optional, Union
+ from typing import Callable, Dict, Optional, Tuple, Union
 
  import jax
 
@@ -103,7 +103,7 @@ class JaxRDDLSimulator(RDDLSimulator):
  self.terminals = jax.tree_util.tree_map(jax.jit, compiled.terminations)
  self.reward = jax.jit(compiled.reward)
  jax_cpfs = jax.tree_util.tree_map(jax.jit, compiled.cpfs)
- self.model_params = compiled.model_params
+ self.model_params = compiled.model_aux['params']
 
  # level analysis
  self.cpfs = []
@@ -116,6 +116,7 @@ class JaxRDDLSimulator(RDDLSimulator):
 
  # initialize all fluent and non-fluent values
  self.subs = self.init_values.copy()
+ self.fls, self.nfls = compiled.split_fluent_nonfluent(self.subs)
  self.state = None
  self.noop_actions = {var: values
  for (var, values) in self.init_values.items()
@@ -142,24 +143,23 @@ class JaxRDDLSimulator(RDDLSimulator):
  for (i, invariant) in enumerate(self.invariants):
  loc = self.invariant_names[i]
  sample, self.key, error, self.model_params = invariant(
- self.subs, self.model_params, self.key)
+ self.fls, self.nfls, self.model_params, self.key)
  self.handle_error_code(error, loc)
  if not bool(sample):
  if not silent:
- raise RDDLStateInvariantNotSatisfiedError(
- f'{loc} is not satisfied.')
+ raise RDDLStateInvariantNotSatisfiedError(f'{loc} is not satisfied.')
  return False
  return True
 
  def check_action_preconditions(self, actions: Args, silent: bool=False) -> bool:
  '''Throws an exception if the action preconditions are not satisfied.'''
- subs = self.subs
- subs.update(actions)
+ self.fls.update(actions)
+ self.subs.update(actions)
 
  for (i, precond) in enumerate(self.preconds):
  loc = self.precond_names[i]
  sample, self.key, error, self.model_params = precond(
- subs, self.model_params, self.key)
+ self.fls, self.nfls, self.model_params, self.key)
  self.handle_error_code(error, loc)
  if not bool(sample):
  if not silent:
@@ -173,7 +173,7 @@ class JaxRDDLSimulator(RDDLSimulator):
  for (i, terminal) in enumerate(self.terminals):
  loc = self.terminal_names[i]
  sample, self.key, error, self.model_params = terminal(
- self.subs, self.model_params, self.key)
+ self.fls, self.nfls, self.model_params, self.key)
  self.handle_error_code(error, loc)
  if bool(sample):
  return True
@@ -182,24 +182,26 @@ class JaxRDDLSimulator(RDDLSimulator):
  def sample_reward(self) -> float:
  '''Samples the current reward given the current state and action.'''
  reward, self.key, error, self.model_params = self.reward(
- self.subs, self.model_params, self.key)
+ self.fls, self.nfls, self.model_params, self.key)
  self.handle_error_code(error, 'reward function')
  return float(reward)
 
- def step(self, actions: Args) -> Args:
+ def step(self, actions: Args) -> Tuple[Args, float, bool]:
  '''Samples and returns the next state from the cpfs.
 
  :param actions: a dict mapping current action fluents to their values
  '''
  rddl = self.rddl
  keep_tensors = self.keep_tensors
- subs = self.subs
+ subs, fls, nfls = self.subs, self.fls, self.nfls
  subs.update(actions)
+ fls.update(actions)
 
  # compute CPFs in topological order
  for (cpf, expr, _) in self.cpfs:
- subs[cpf], self.key, error, self.model_params = expr(
- subs, self.model_params, self.key)
+ fls[cpf], self.key, error, self.model_params = expr(
+ fls, nfls, self.model_params, self.key)
+ subs[cpf] = fls[cpf]
  self.handle_error_code(error, f'CPF <{cpf}>')
 
  # sample reward
@@ -210,10 +212,11 @@ class JaxRDDLSimulator(RDDLSimulator):
  for (state, next_state) in rddl.next_state.items():
 
  # set state = state' for the next epoch
+ fls[state] = fls[next_state]
  subs[state] = subs[next_state]
 
  # convert object integer to string representation
- state_values = subs[state]
+ state_values = fls[state]
  if self.objects_as_strings:
  ptype = rddl.variable_ranges[state]
  if ptype not in RDDLValueInitializer.NUMPY_TYPES:
@@ -231,7 +234,7 @@ class JaxRDDLSimulator(RDDLSimulator):
  for var in rddl.observ_fluents:
 
  # convert object integer to string representation
- obs_values = subs[var]
+ obs_values = fls[var]
  if self.objects_as_strings:
  ptype = rddl.variable_ranges[var]
  if ptype not in RDDLValueInitializer.NUMPY_TYPES:
@@ -244,7 +247,7 @@ class JaxRDDLSimulator(RDDLSimulator):
  obs.update(rddl.ground_var_with_values(var, obs_values))
  else:
  obs = self.state
-
+
  done = self.check_terminal_states()
  return obs, reward, done
 
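Note on the core/simulator.py hunks above: in 3.0 every compiled JAX expression (CPFs, reward, invariants, preconditions, terminations) is called with separate fluent and non-fluent dictionaries plus the model parameters and a PRNG key, where 2.8 passed a single substitution dictionary, and step() now declares the (observations, reward, done) tuple it returns. The snippet below is a minimal, self-contained sketch of that calling convention only; stub_cpf and the fluent names are invented for illustration and are not produced by the package.

    # Hypothetical sketch of the 3.0 calling convention seen in the diff above.
    import jax
    import jax.numpy as jnp

    def stub_cpf(fls, nfls, model_params, key):
        # a compiled CPF in 3.0 receives (fluents, non-fluents, params, key) and
        # returns (value, key, error code, params); this stub mimics that shape
        value = fls['level'] + nfls['inflow']
        return value, key, 0, model_params

    fls = {'level': jnp.array(1.0)}     # state/action/interm fluents, updated each step
    nfls = {'inflow': jnp.array(0.5)}   # non-fluents, fixed for the instance
    model_params, key = {}, jax.random.PRNGKey(42)

    fls['level'], key, error, model_params = stub_cpf(fls, nfls, model_params, key)
    print(float(fls['level']))  # 1.5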
pyRDDLGym_jax/core/tuning.py
@@ -248,8 +248,8 @@ class JaxParameterTuning:
  policy = JaxOfflineController(
  planner=planner, key=subkey, tqdm_position=index,
  params=best_params, train_on_reset=False)
- total_reward = policy.evaluate(env, episodes=rollouts_per_trial,
- seed=np.array(subkey)[0])['mean']
+ total_reward = policy.evaluate(
+ env, episodes=rollouts_per_trial, seed=np.array(subkey)[0])['mean']
 
  # update average reward
  if verbose:
@@ -321,7 +321,8 @@ class JaxParameterTuning:
  index: int,
  iteration: int,
  kwargs: Kwargs,
- queue: object) -> Tuple[ParameterValues, float, int, int]:
+ queue: object,
+ show_dashboard: bool) -> Tuple[ParameterValues, float, int, int]:
  '''A pickleable objective function to evaluate a single hyper-parameter
  configuration.'''
 
@@ -345,7 +346,10 @@ class JaxParameterTuning:
  planner_args, _, train_args = load_config_from_string(config_string)
 
  # remove keywords that should not be in the tuner
- train_args.pop('dashboard', None)
+ if show_dashboard:
+ planner_args['dashboard'] = True
+ else:
+ planner_args['dashboard'] = None
  planner_args.pop('parallel_updates', None)
 
  # initialize env for evaluation (need fresh copy to avoid concurrency)
@@ -353,6 +357,7 @@ class JaxParameterTuning:
 
  # run planning algorithm
  planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
+ planner.dashboard = None
  if online:
  average_reward = JaxParameterTuning.online_trials(
  env, planner, train_args, key, iteration, index, num_trials,
@@ -482,7 +487,7 @@ class JaxParameterTuning:
  # assign jobs to worker pool
  results = [
  pool.apply_async(JaxParameterTuning.objective_function,
- obj_args + (it, obj_kwargs, queue))
+ obj_args + (it, obj_kwargs, queue, show_dashboard))
  for obj_args in zip(suggested_params, subkeys, worker_ids)
  ]
 
@@ -502,8 +507,7 @@ class JaxParameterTuning:
  # extract and register the new evaluation
  params, target, index, pid = results.pop(i).get()
  optimizer.register(params, target)
- optimizer._gp.fit(
- optimizer.space.params, optimizer.space.target)
+ optimizer._gp.fit(optimizer.space.params, optimizer.space.target)
 
  # update acquisition function and suggest a new point
  suggested_params[index] = optimizer.suggest()
pyRDDLGym_jax/core/visualization.py
@@ -18,6 +18,8 @@ import os
  from datetime import datetime
  import math
  import numpy as np
+ import io
+ import pickle
  import time
  import threading
  from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
@@ -29,7 +31,7 @@ log = logging.getLogger('werkzeug')
  log.setLevel(logging.ERROR)
 
  import dash
- from dash.dcc import Interval, Graph, Store
+ from dash.dcc import Download, Interval, Graph, Store
  from dash.dependencies import Input, Output, State, ALL
  from dash.html import Div, B, H4, P, Hr
  import dash_bootstrap_components as dbc
@@ -48,6 +50,7 @@ POLICY_DIST_PLOTS_PER_ROW = 6
  ACTION_HEATMAP_HEIGHT = 400
  PROGRESS_FOR_NEXT_RETURN_DIST = 2
  PROGRESS_FOR_NEXT_POLICY_DIST = 10
+ PROGRESS_FOR_NEXT_BASIC_TIME_CURVE = 0.05
  REWARD_ERROR_DIST_SUBPLOTS = 20
  MODEL_STATE_ERROR_HEIGHT = 300
  POLICY_STATE_VIZ_MAX_HEIGHT = 800
@@ -77,6 +80,7 @@ class JaxPlannerDashboard:
  self.test_return = {}
  self.train_return = {}
  self.pgpe_return = {}
+ self.basic_time_curve_last_progress = {}
  self.return_dist = {}
  self.return_dist_ticks = {}
  self.return_dist_last_progress = {}
@@ -91,7 +95,8 @@ class JaxPlannerDashboard:
  self.train_reward_dist = {}
  self.test_reward_dist = {}
  self.train_state_fluents = {}
- self.test_state_fluents = {}
+ self.train_state_output = {}
+ self.test_state_output = {}
 
  self.tuning_gp_heatmaps = None
  self.tuning_gp_targets = None
@@ -269,6 +274,8 @@ class JaxPlannerDashboard:
  dbc.DropdownMenuItem("30s", id='30sec'),
  dbc.DropdownMenuItem("1m", id='1min'),
  dbc.DropdownMenuItem("5m", id='5min'),
+ dbc.DropdownMenuItem("30m", id='30min'),
+ dbc.DropdownMenuItem("1h", id='1h'),
  dbc.DropdownMenuItem("1d", id='1day')],
  label="Refresh: 2s",
  id='refresh-rate-dropdown',
@@ -281,7 +288,7 @@ class JaxPlannerDashboard:
  dbc.DropdownMenuItem("10", id='10pp'),
  dbc.DropdownMenuItem("25", id='25pp'),
  dbc.DropdownMenuItem("50", id='50pp')],
- label="Exp. Per Page: 10",
+ label="Results Per Page: 10",
  id='experiment-num-per-page-dropdown',
  nav=True
  )
@@ -328,6 +335,13 @@ class JaxPlannerDashboard:
  # policy
  dbc.Tab(dbc.Card(
  dbc.CardBody([
+ dbc.Row([
+ dbc.Col([
+ dbc.Button('Save Policy Weights',
+ id='policy-save-button'),
+ Download(id="download-policy")
+ ], width='auto')
+ ]),
  dbc.Row([
  Graph(id='action-output'),
  ]),
@@ -506,10 +520,12 @@ class JaxPlannerDashboard:
  Input("30sec", "n_clicks"),
  Input("1min", "n_clicks"),
  Input("5min", "n_clicks"),
+ Input("30min", "n_clicks"),
+ Input("1h", "n_clicks"),
  Input("1day", "n_clicks")],
  [State('refresh-interval', 'data')]
  )
- def click_refresh_rate(n05, n1, n2, n5, n10, n30, n1m, n5m, nd, data):
+ def click_refresh_rate(n05, n1, n2, n5, n10, n30, n1m, n5m, n30m, n1h, nd, data):
  ctx = dash.callback_context
  if not ctx.triggered:
  return data
@@ -530,6 +546,10 @@ class JaxPlannerDashboard:
  return 60000
  elif button_id == '5min':
  return 300000
+ elif button_id == '30min':
+ return 1800000
+ elif button_id == '1h':
+ return 3600000
  elif button_id == '1day':
  return 86400000
  return data
@@ -562,8 +582,14 @@ class JaxPlannerDashboard:
  return 'Refresh: 1m'
  elif selected_interval == 300000:
  return 'Refresh: 5m'
+ elif selected_interval == 1800000:
+ return 'Refresh: 30m'
+ elif selected_interval == 3600000:
+ return 'Refresh: 1h'
+ elif selected_interval == 86400000:
+ return 'Refresh: 1day'
  else:
- return 'Refresh: 2s'
+ return 'Refresh: n/a'
 
  # update the experiments per page
  @app.callback(
@@ -594,7 +620,7 @@ class JaxPlannerDashboard:
  [Input('experiment-num-per-page', 'data')]
  )
  def update_experiments_per_page(selected_num):
- return f'Exp. Per Page: {selected_num}'
+ return f'Results Per Page: {selected_num}'
 
  # update the experiment table
  @app.callback(
@@ -758,7 +784,7 @@ class JaxPlannerDashboard:
  if checked and self.action_output[row] is not None:
  num_plots = len(self.action_output[row])
  titles = []
- for (_, act, _) in self.action_output[row]:
+ for act in self.action_output[row].keys():
  titles.append(f'Values of Action-Fluents {act}')
  titles.append(f'Std. Dev. of Action-Fluents {act}')
  fig = make_subplots(
@@ -766,8 +792,7 @@ class JaxPlannerDashboard:
  shared_xaxes=True, horizontal_spacing=0.15,
  subplot_titles=titles
  )
- for (i, (action_output, action, action_labels)) \
- in enumerate(self.action_output[row]):
+ for (i, action_output) in enumerate(self.action_output[row].values()):
  action_values = np.mean(1. * action_output, axis=0).T
  action_errors = np.std(1. * action_output, axis=0).T
  fig.add_trace(go.Heatmap(
@@ -984,6 +1009,21 @@ class JaxPlannerDashboard:
  return fig
  return dash.no_update
 
+ # save policy button
+ @app.callback(
+ Output('download-policy', 'data'),
+ Input("policy-save-button", "n_clicks"),
+ prevent_initial_call=True
+ )
+ def save_policy_weights(n_clicks):
+ for (row, checked) in self.checked.copy().items():
+ if checked:
+ bytes_io = io.BytesIO()
+ pickle.dump(self.policy_params[row], bytes_io)
+ bytes_io.seek(0)
+ return dash.dcc.send_bytes(bytes_io.read(), "policy_params.pkl")
+ return dash.no_update
+
  # update the model parameter information
  @app.callback(
  Output('model-params-dropdown', 'children'),
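The 'Save Policy Weights' button added to core/visualization.py above serializes the checked experiment's stored policy parameters with pickle and serves them to the browser as policy_params.pkl. A file downloaded this way can be read back with the standard library; the exact structure of the unpickled object (typically a pytree of arrays per policy parameter) is an assumption here, since the diff only shows how it is written.

    # Minimal sketch: read back a file downloaded from the dashboard's policy tab.
    import pickle

    with open('policy_params.pkl', 'rb') as f:   # path of the downloaded file
        policy_params = pickle.load(f)

    print(type(policy_params))  # inspect what the planner stored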
@@ -1136,42 +1176,40 @@ class JaxPlannerDashboard:
  if not state: return fig
  for (row, checked) in self.checked.copy().items():
  if checked and row in self.train_state_fluents:
- train_values = self.train_state_fluents[row][state]
- test_values = self.test_state_fluents[row][state]
- train_values = 1 * train_values.reshape(train_values.shape[:2] + (-1,))
- test_values = 1 * test_values.reshape(test_values.shape[:2] + (-1,))
- num_epochs, num_states = train_values.shape[1:]
- step = 1
- if num_epochs > REWARD_ERROR_DIST_SUBPLOTS:
- step = num_epochs // REWARD_ERROR_DIST_SUBPLOTS
+ titles = [f'Values of Train State-Fluents {state}',
+ f'Values of Test State-Fluents {state}']
  fig = make_subplots(
- rows=num_states, cols=1, shared_xaxes=True,
- subplot_titles=self.rddl[row].variable_groundings[state]
+ rows=1, cols=2,
+ shared_xaxes=True, horizontal_spacing=0.15,
+ subplot_titles=titles
  )
- for istate in range(num_states):
- for epoch in range(0, num_epochs, step):
- fig.add_trace(go.Violin(
- y=train_values[:, epoch, istate], x0=epoch,
- side='negative', line_color='red',
- name=f'Train Epoch {epoch + 1}'
- ), row=istate + 1, col=1)
- fig.add_trace(go.Violin(
- y=test_values[:, epoch, istate], x0=epoch,
- side='positive', line_color='blue',
- name=f'Test Epoch {epoch + 1}'
- ), row=istate + 1, col=1)
- fig.update_traces(meanline_visible=True)
+ train_state_output = self.train_state_output[row][state]
+ test_state_output = self.test_state_output[row][state]
+ train_state_values = np.mean(1. * train_state_output, axis=0).T
+ test_state_values = np.mean(1. * test_state_output, axis=0).T
+ fig.add_trace(go.Heatmap(
+ z=train_state_values,
+ x=np.arange(train_state_values.shape[1]),
+ y=np.arange(train_state_values.shape[0]),
+ colorscale='Blues', colorbar_x=0.45,
+ colorbar_len=0.8 / 1,
+ colorbar_y=1 - (0.5) / 1
+ ), row=1, col=1)
+ fig.add_trace(go.Heatmap(
+ z=test_state_values,
+ x=np.arange(test_state_values.shape[1]),
+ y=np.arange(test_state_values.shape[0]),
+ colorscale='Blues', colorbar_len=0.8 / 1,
+ colorbar_y=1 - (0.5) / 1
+ ), row=1, col=2)
  fig.update_layout(
- title=dict(text=(f"Distribution of State-Fluent {state} "
- f"in Relaxed Model vs True Model")),
+ title=f"Values of State-Fluents {state}",
  xaxis=dict(title=dict(text="Decision Epoch")),
- yaxis=dict(title=dict(text="State-Fluent Value")),
  font=dict(size=PLOT_AXES_FONT_SIZE),
- height=MODEL_STATE_ERROR_HEIGHT * num_states,
- violingap=0, violinmode='overlay', showlegend=False,
- legend=dict(bgcolor='rgba(0,0,0,0)'),
+ height=ACTION_HEATMAP_HEIGHT * 1,
+ showlegend=False,
  template="plotly_white"
- )
+ )
  break
  return fig
 
@@ -1363,6 +1401,7 @@ class JaxPlannerDashboard:
  self.train_return[experiment_id] = []
  self.test_return[experiment_id] = []
  self.pgpe_return[experiment_id] = []
+ self.basic_time_curve_last_progress[experiment_id] = 0
  self.return_dist_ticks[experiment_id] = []
  self.return_dist_last_progress[experiment_id] = 0
  self.return_dist[experiment_id] = []
@@ -1410,64 +1449,63 @@ class JaxPlannerDashboard:
  '''Pass new information and update the dashboard for a given experiment.'''
 
  # data for return curves
+ progress = callback['progress']
  iteration = callback['iteration']
- self.xticks[experiment_id].append(iteration)
- self.train_return[experiment_id].append(callback['train_return'])
- self.test_return[experiment_id].append(callback['best_return'])
- self.pgpe_return[experiment_id].append(callback['pgpe_return'])
+ if progress - self.basic_time_curve_last_progress[experiment_id] >= PROGRESS_FOR_NEXT_BASIC_TIME_CURVE:
+ self.xticks[experiment_id].append(iteration)
+ self.train_return[experiment_id].append(np.min(callback['train_return']))
+ self.test_return[experiment_id].append(np.min(callback['best_return']))
+ self.pgpe_return[experiment_id].append(np.min(callback['pgpe_return']))
+ for (key, values) in callback['model_params'].items():
+ self.relaxed_exprs_values[experiment_id][key].append(values[0])
+ self.basic_time_curve_last_progress[experiment_id] = progress
 
  # data for return distributions
- progress = int(callback['progress'])
- if progress - self.return_dist_last_progress[experiment_id] \
- >= PROGRESS_FOR_NEXT_RETURN_DIST:
+ if progress - self.return_dist_last_progress[experiment_id] >= PROGRESS_FOR_NEXT_RETURN_DIST:
  self.return_dist_ticks[experiment_id].append(iteration)
  self.return_dist[experiment_id].append(
- np.sum(np.asarray(callback['reward']), axis=1))
+ np.sum(np.mean(callback['test_log']['reward'], axis=0), axis=1))
  self.return_dist_last_progress[experiment_id] = progress
 
+ # data for policy weight distributions
+ if progress - self.policy_params_last_progress[experiment_id] >= PROGRESS_FOR_NEXT_POLICY_DIST:
+ self.policy_params_ticks[experiment_id].append(iteration)
+ self.policy_params[experiment_id].append(callback['best_params'])
+ self.policy_params_last_progress[experiment_id] = progress
+
  # data for action heatmaps
- action_output = []
+ action_output = {}
  rddl = self.rddl[experiment_id]
  for action in rddl.action_fluents:
- action_values = np.asarray(callback['fluents'][action])
- action_output.append(
- (action_values.reshape(action_values.shape[:2] + (-1,)),
- action,
- rddl.variable_groundings[action])
- )
+ action_values = np.asarray(callback['test_log']['fluents'][action][0])
+ action_output[action] = action_values.reshape(action_values.shape[:2] + (-1,))
  self.action_output[experiment_id] = action_output
 
- # data for policy weight distributions
- if progress - self.policy_params_last_progress[experiment_id] \
- >= PROGRESS_FOR_NEXT_POLICY_DIST:
- self.policy_params_ticks[experiment_id].append(iteration)
- self.policy_params[experiment_id].append(callback['best_params'])
- self.policy_params_last_progress[experiment_id] = progress
+ # data for state heatmaps
+ train_state_output = {}
+ test_state_output = {}
+ for state in rddl.state_fluents:
+ state_values = np.asarray(callback['train_log']['fluents'][state][0])
+ train_state_output[state] = state_values.reshape(state_values.shape[:2] + (-1,))
+ state_values = np.asarray(callback['test_log']['fluents'][state][0])
+ test_state_output[state] = state_values.reshape(state_values.shape[:2] + (-1,))
+ self.train_state_output[experiment_id] = train_state_output
+ self.test_state_output[experiment_id] = test_state_output
 
- # data for model relaxations
- model_params = callback['model_params']
- for (key, values) in model_params.items():
- expr_id = int(str(key).split('_')[0])
- self.relaxed_exprs_values[experiment_id][expr_id].append(values.item())
- self.train_reward_dist[experiment_id] = callback['train_log']['reward']
- self.test_reward_dist[experiment_id] = callback['reward']
+ # data for reward distributions
+ self.train_reward_dist[experiment_id] = np.mean(callback['train_log']['reward'], axis=0)
+ self.test_reward_dist[experiment_id] = np.mean(callback['test_log']['reward'], axis=0)
  self.train_state_fluents[experiment_id] = {
- name: np.asarray(callback['train_log']['fluents'][name])
+ name: np.asarray(callback['train_log']['fluents'][name][0])
  for name in rddl.state_fluents
  }
- self.test_state_fluents[experiment_id] = {
- name: np.asarray(callback['fluents'][name])
- for name in self.train_state_fluents[experiment_id]
- }
-
  # update experiment table info
  self.status[experiment_id] = str(callback['status']).split('.')[1]
  self.duration[experiment_id] = callback["elapsed_time"]
- self.progress[experiment_id] = progress
+ self.progress[experiment_id] = int(progress)
  self.warnings = None
 
- def update_tuning(self, optimizer: Any,
- bounds: Dict[str, Tuple[float, float]]) -> None:
+ def update_tuning(self, optimizer: Any, bounds: Dict[str, Tuple[float, float]]) -> None:
  '''Updates the hyper-parameter tuning plots.'''
 
  self.tuning_gp_heatmaps = []
@@ -1475,8 +1513,7 @@ class JaxPlannerDashboard:
  if not optimizer.res: return
 
  self.tuning_gp_targets = optimizer.space.target.reshape((-1,))
- self.tuning_gp_predicted = \
- optimizer._gp.predict(optimizer.space.params).reshape((-1,))
+ self.tuning_gp_predicted = optimizer._gp.predict(optimizer.space.params).reshape((-1,))
  self.tuning_gp_params = {name: optimizer.space.params[:, i]
  for (i, name) in enumerate(optimizer.space.keys)}
 
pyRDDLGym_jax/entry_point.py
@@ -1,6 +1,5 @@
  import argparse
 
- from pyRDDLGym_jax.examples import run_plan, run_tune
 
  EPILOG = 'For complete documentation, see https://pyrddlgym.readthedocs.io/en/latest/jax.html.'
 
@@ -47,8 +46,10 @@ def main():
  # dispatch
  args = parser.parse_args()
  if args.jaxplan == "plan":
+ from pyRDDLGym_jax.examples import run_plan
  run_plan.main(args.domain, args.instance, args.method, args.episodes)
  elif args.jaxplan == "tune":
+ from pyRDDLGym_jax.examples import run_tune
  run_tune.main(args.domain, args.instance, args.method,
  args.trials, args.iters, args.workers, args.dashboard,
  args.filepath)
@@ -1,17 +1,15 @@
- [Model]
- logic='FuzzyLogic'
- comparison_kwargs={'weight': 20}
- rounding_kwargs={'weight': 20}
- control_kwargs={'weight': 20}
+ [Compiler]
+ method='DefaultJaxRDDLCompilerWithGrad'
+ sigmoid_weight=20
 
- [Optimizer]
+ [Planner]
  method='JaxDeepReactivePolicy'
- method_kwargs={'topology': [32, 16]}
+ method_kwargs={'topology': [32, 32]}
  optimizer='rmsprop'
  optimizer_kwargs={'learning_rate': 0.005}
  batch_size_train=1
  batch_size_test=1
 
- [Training]
+ [Optimize]
  key=42
  epochs=1000
@@ -1,10 +1,8 @@
- [Model]
- logic='FuzzyLogic'
- comparison_kwargs={'weight': 20}
- rounding_kwargs={'weight': 20}
- control_kwargs={'weight': 20}
+ [Compiler]
+ method='DefaultJaxRDDLCompilerWithGrad'
+ sigmoid_weight=20
 
- [Optimizer]
+ [Planner]
  method='JaxStraightLinePlan'
  method_kwargs={}
  optimizer='rmsprop'
@@ -13,7 +11,7 @@ batch_size_train=1
  batch_size_test=1
  rollout_horizon=30
 
- [Training]
+ [Optimize]
  key=42
  train_seconds=0.5
  print_summary=False
@@ -1,10 +1,8 @@
- [Model]
- logic='FuzzyLogic'
- comparison_kwargs={'weight': 20}
- rounding_kwargs={'weight': 20}
- control_kwargs={'weight': 20}
+ [Compiler]
+ method='DefaultJaxRDDLCompilerWithGrad'
+ sigmoid_weight=20
 
- [Optimizer]
+ [Planner]
  method='JaxStraightLinePlan'
  method_kwargs={}
  optimizer='rmsprop'
@@ -12,7 +10,8 @@ optimizer_kwargs={'learning_rate': 0.001}
  batch_size_train=1
  batch_size_test=1
  clip_grad=1.0
+ pgpe_kwargs={'optimizer_kwargs_mu': {'learning_rate': 0.01}, 'optimizer_kwargs_sigma': {'learning_rate': 0.01}}
 
- [Training]
+ [Optimize]
  key=42
- epochs=5000
+ epochs=3000
@@ -1,10 +1,9 @@
- [Model]
- logic='FuzzyLogic'
- comparison_kwargs={'weight': 5}
- rounding_kwargs={'weight': 5}
- control_kwargs={'weight': 5}
+ [Compiler]
+ method='DefaultJaxRDDLCompilerWithGrad'
+ bernoulli_sigmoid_weight=5
+ sigmoid_weight=5
 
- [Optimizer]
+ [Planner]
  method='JaxDeepReactivePolicy'
  method_kwargs={'topology': [64, 64]}
  optimizer='rmsprop'
@@ -12,7 +11,7 @@ optimizer_kwargs={'learning_rate': 0.001}
  batch_size_train=1
  batch_size_test=1
 
- [Training]
+ [Optimize]
  key=42
  epochs=6000
- train_seconds=60
+ train_seconds=90
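Taken together, the configuration hunks above replace the old [Model]/[Optimizer]/[Training] sections with [Compiler]/[Planner]/[Optimize]. Below is a minimal sketch of a complete config in the new layout, assembled only from section names and keys that appear in this diff; it is an illustration of the format rather than a recommended configuration (the tuning code above consumes such text via load_config_from_string), and it is parsed here with the standard library only.

    # Hypothetical 3.0-style config string; values are copied from the hunks above.
    import configparser

    CONFIG = """
    [Compiler]
    method='DefaultJaxRDDLCompilerWithGrad'
    sigmoid_weight=20

    [Planner]
    method='JaxDeepReactivePolicy'
    method_kwargs={'topology': [32, 32]}
    optimizer='rmsprop'
    optimizer_kwargs={'learning_rate': 0.005}
    batch_size_train=1
    batch_size_test=1

    [Optimize]
    key=42
    epochs=1000
    """

    # strip the presentation indent before parsing, then show the three sections
    parser = configparser.ConfigParser()
    parser.read_string("\n".join(line.strip() for line in CONFIG.splitlines()))
    print(parser.sections())  # ['Compiler', 'Planner', 'Optimize']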