pyRDDLGym-jax 2.0__py3-none-any.whl → 2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -20,8 +20,7 @@ import math
20
20
  import numpy as np
21
21
  import time
22
22
  import threading
23
- from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
24
- import warnings
23
+ from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
25
24
  import webbrowser
26
25
 
27
26
  # prevent endless console prints
@@ -32,7 +31,7 @@ log.setLevel(logging.ERROR)
32
31
  import dash
33
32
  from dash.dcc import Interval, Graph, Store
34
33
  from dash.dependencies import Input, Output, State, ALL
35
- from dash.html import Div, B, H4, P, Img, Hr
34
+ from dash.html import Div, B, H4, P, Hr
36
35
  import dash_bootstrap_components as dbc
37
36
 
38
37
  import plotly.colors as pc
@@ -53,6 +52,7 @@ REWARD_ERROR_DIST_SUBPLOTS = 20
53
52
  MODEL_STATE_ERROR_HEIGHT = 300
54
53
  POLICY_STATE_VIZ_MAX_HEIGHT = 800
55
54
  GP_POSTERIOR_MAX_HEIGHT = 800
55
+ GP_POSTERIOR_PIXELS = 100
56
56
 
57
57
  PLOT_AXES_FONT_SIZE = 11
58
58
  EXPERIMENT_ENTRY_FONT_SIZE = 14
@@ -1417,7 +1417,7 @@ class JaxPlannerDashboard:
1417
1417
  self.pgpe_return[experiment_id].append(callback['pgpe_return'])
1418
1418
 
1419
1419
  # data for return distributions
1420
- progress = callback['progress']
1420
+ progress = int(callback['progress'])
1421
1421
  if progress - self.return_dist_last_progress[experiment_id] \
1422
1422
  >= PROGRESS_FOR_NEXT_RETURN_DIST:
1423
1423
  self.return_dist_ticks[experiment_id].append(iteration)
@@ -1486,8 +1486,8 @@ class JaxPlannerDashboard:
1486
1486
  if i2 > i1:
1487
1487
 
1488
1488
  # Generate a grid for visualization
1489
- p1_values = np.linspace(*bounds[param1], 100)
1490
- p2_values = np.linspace(*bounds[param2], 100)
1489
+ p1_values = np.linspace(*bounds[param1], GP_POSTERIOR_PIXELS)
1490
+ p2_values = np.linspace(*bounds[param2], GP_POSTERIOR_PIXELS)
1491
1491
  P1, P2 = np.meshgrid(p1_values, p2_values)
1492
1492
 
1493
1493
  # Predict the mean and deviation of the surrogate model
@@ -1500,8 +1500,7 @@ class JaxPlannerDashboard:
1500
1500
  for p1, p2 in zip(np.ravel(P1), np.ravel(P2)):
1501
1501
  params = {param1: p1, param2: p2}
1502
1502
  params.update(fixed_params)
1503
- param_grid.append(
1504
- [params[key] for key in optimizer.space.keys])
1503
+ param_grid.append([params[key] for key in optimizer.space.keys])
1505
1504
  param_grid = np.asarray(param_grid)
1506
1505
  mean, std = optimizer._gp.predict(param_grid, return_std=True)
1507
1506
  mean = mean.reshape(P1.shape)
@@ -3,7 +3,7 @@ is performed using a batched parallelized Bayesian optimization.
3
3
 
4
4
  The syntax is:
5
5
 
6
- python run_tune.py <domain> <instance> <method> [<trials>] [<iters>] [<workers>]
6
+ python run_tune.py <domain> <instance> <method> [<trials>] [<iters>] [<workers>] [<dashboard>]
7
7
 
8
8
  where:
9
9
  <domain> is the name of a domain located in the /Examples directory
@@ -15,6 +15,7 @@ where:
15
15
  (defaults to 20)
16
16
  <workers> is the number of parallel workers (i.e. batch size), which must
17
17
  not exceed the number of cores available on the machine (defaults to 4)
18
+ <dashboard> is whether the dashboard is displayed
18
19
  '''
19
20
  import os
20
21
  import sys
@@ -35,7 +36,7 @@ def power_10(x):
35
36
  return 10.0 ** x
36
37
 
37
38
 
38
- def main(domain, instance, method, trials=5, iters=20, workers=4):
39
+ def main(domain, instance, method, trials=5, iters=20, workers=4, dashboard=False):
39
40
 
40
41
  # set up the environment
41
42
  env = pyRDDLGym.make(domain, instance, vectorized=True)
@@ -48,9 +49,9 @@ def main(domain, instance, method, trials=5, iters=20, workers=4):
48
49
 
49
50
  # map parameters in the config that will be tuned
50
51
  hyperparams = [
51
- Hyperparameter('MODEL_WEIGHT_TUNE', -1., 5., power_10),
52
+ Hyperparameter('MODEL_WEIGHT_TUNE', -1., 4., power_10),
52
53
  Hyperparameter('POLICY_WEIGHT_TUNE', -2., 2., power_10),
53
- Hyperparameter('LEARNING_RATE_TUNE', -5., 1., power_10),
54
+ Hyperparameter('LEARNING_RATE_TUNE', -5., 0., power_10),
54
55
  Hyperparameter('LAYER1_TUNE', 1, 8, power_2),
55
56
  Hyperparameter('LAYER2_TUNE', 1, 8, power_2),
56
57
  Hyperparameter('ROLLOUT_HORIZON_TUNE', 1, min(env.horizon, 100), int)
@@ -64,7 +65,9 @@ def main(domain, instance, method, trials=5, iters=20, workers=4):
64
65
  eval_trials=trials,
65
66
  num_workers=workers,
66
67
  gp_iters=iters)
67
- tuning.tune(key=42, log_file=f'gp_{method}_{domain}_{instance}.csv')
68
+ tuning.tune(key=42,
69
+ log_file=f'gp_{method}_{domain}_{instance}.csv',
70
+ show_dashboard=dashboard)
68
71
 
69
72
  # evaluate the agent on the best parameters
70
73
  planner_args, _, train_args = load_config_from_string(tuning.best_config)
@@ -77,7 +80,7 @@ def main(domain, instance, method, trials=5, iters=20, workers=4):
77
80
 
78
81
  def run_from_args(args):
79
82
  if len(args) < 3:
80
- print('python run_tune.py <domain> <instance> <method> [<trials>] [<iters>] [<workers>]')
83
+ print('python run_tune.py <domain> <instance> <method> [<trials>] [<iters>] [<workers>] [<dashboard>]')
81
84
  exit(1)
82
85
  if args[2] not in ['drp', 'slp', 'replan']:
83
86
  print('<method> in [drp, slp, replan]')
@@ -86,6 +89,7 @@ def run_from_args(args):
86
89
  if len(args) >= 4: kwargs['trials'] = int(args[3])
87
90
  if len(args) >= 5: kwargs['iters'] = int(args[4])
88
91
  if len(args) >= 6: kwargs['workers'] = int(args[5])
92
+ if len(args) >= 7: kwargs['dashboard'] = bool(args[6])
89
93
  main(**kwargs)
90
94
 
91
95
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: pyRDDLGym-jax
3
- Version: 2.0
3
+ Version: 2.2
4
4
  Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
5
5
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
6
6
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -58,18 +58,21 @@ Dynamic: summary
58
58
 
59
59
  Purpose:
60
60
 
61
- 1. automatic translation of any RDDL description file into a differentiable simulator in JAX
62
- 2. flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and Bayesian hyper-parameter tuning.
61
+ 1. automatic translation of RDDL description files into differentiable JAX simulators
62
+ 2. implementation of (highly configurable) operator relaxations for working in discrete and hybrid domains
63
+ 3. flexible policy representations and automated Bayesian hyper-parameter tuning
64
  + 4. interactive dashboard for dynamic visualization and debugging
65
+ 5. hybridization with parameter-exploring policy gradients.
63
66
 
64
67
  Some demos of solved problems by JaxPlan:
65
68
 
66
69
  <p align="middle">
67
- <img src="Images/intruders.gif" width="120" height="120" margin=0/>
68
- <img src="Images/marsrover.gif" width="120" height="120" margin=0/>
69
- <img src="Images/pong.gif" width="120" height="120" margin=0/>
70
- <img src="Images/quadcopter.gif" width="120" height="120" margin=0/>
71
- <img src="Images/reacher.gif" width="120" height="120" margin=0/>
72
- <img src="Images/reservoir.gif" width="120" height="120" margin=0/>
70
+ <img src="https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/Images/intruders.gif" width="120" height="120" margin=0/>
71
+ <img src="https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/Images/marsrover.gif" width="120" height="120" margin=0/>
72
+ <img src="https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/Images/pong.gif" width="120" height="120" margin=0/>
73
+ <img src="https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/Images/quadcopter.gif" width="120" height="120" margin=0/>
74
+ <img src="https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/Images/reacher.gif" width="120" height="120" margin=0/>
75
+ <img src="https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/Images/reservoir.gif" width="120" height="120" margin=0/>
73
76
  </p>
74
77
 
75
78
  > [!WARNING]
@@ -219,7 +222,7 @@ Since version 1.0, JaxPlan has an optional dashboard that allows keeping track o
219
222
  and visualization of the policy or model, and other useful debugging features.
220
223
 
221
224
  <p align="middle">
222
- <img src="Images/dashboard.png" width="480" height="248" margin=0/>
225
+ <img src="https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/Images/dashboard.png" width="480" height="248" margin=0/>
223
226
  </p>
224
227
 
225
228
  To run the dashboard, add the following entry to your config file:
@@ -235,8 +238,23 @@ More documentation about this and other new features will be coming soon.
235
238
 
236
239
  ## Tuning the Planner
237
240
 
238
- It is easy to tune the planner's hyper-parameters efficiently and automatically using Bayesian optimization.
239
- To do this, first create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
241
+ A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:
242
+
243
+ ```shell
244
+ jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
245
+ ```
246
+
247
+ where:
248
+ - ``domain`` is the domain identifier as specified in rddlrepository
249
+ - ``instance`` is the instance identifier
250
+ - ``method`` is the planning method to use (i.e. drp, slp, replan)
251
+ - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
252
+ - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
253
+ - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``
254
+ - ``dashboard`` is whether the optimizations are tracked in the dashboard application.
255
+
256
+ It is easy to tune a custom range of the planner's hyper-parameters efficiently.
257
+ First create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
240
258
 
241
259
  ```ini
242
260
  [Model]
@@ -260,7 +278,7 @@ train_on_reset=True
260
278
 
261
279
  would allow to tune the sharpness of model relaxations, and the learning rate of the optimizer.
262
280
 
263
- Next, you must link the patterns in the config with concrete hyper-parameter ranges the tuner will understand:
281
+ Next, you must link the patterns in the config with concrete hyper-parameter ranges the tuner will understand, and run the optimizer:
264
282
 
265
283
  ```python
266
284
  import pyRDDLGym
@@ -292,21 +310,7 @@ tuning = JaxParameterTuning(env=env,
292
310
  gp_iters=iters)
293
311
  tuning.tune(key=42, log_file='path/to/log.csv')
294
312
  ```
295
-
296
- A basic run script is provided to run the automatic hyper-parameter tuning for the most sensitive parameters of JaxPlan:
297
-
298
- ```shell
299
- jaxplan tune <domain> <instance> <method> <trials> <iters> <workers>
300
- ```
301
-
302
- where:
303
- - ``domain`` is the domain identifier as specified in rddlrepository
304
- - ``instance`` is the instance identifier
305
- - ``method`` is the planning method to use (i.e. drp, slp, replan)
306
- - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
307
- - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
308
- - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``.
309
-
313
+
310
314
 
311
315
  ## Simulation
312
316
 
@@ -344,7 +348,16 @@ The [following citation](https://ojs.aaai.org/index.php/ICAPS/article/view/31480
344
348
  ```
345
349
 
346
350
  Some of the implementation details derive from the following literature, which you may wish to also cite in your research papers:
347
- - [A Distributional Framework for Risk-Sensitive End-to-End Planning in Continuous MDPs](https://ojs.aaai.org/index.php/AAAI/article/view/21226)
351
+ - [A Distributional Framework for Risk-Sensitive End-to-End Planning in Continuous MDPs, AAAI 2022](https://ojs.aaai.org/index.php/AAAI/article/view/21226)
348
352
  - [Deep reactive policies for planning in stochastic nonlinear domains, AAAI 2019](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
353
  + - [Stochastic Planning with Lifted Symbolic Trajectory Optimization, ICAPS 2019](https://ojs.aaai.org/index.php/ICAPS/article/view/3467/3335)
349
354
  - [Scalable planning with tensorflow for hybrid nonlinear domains, NeurIPS 2017](https://proceedings.neurips.cc/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)
350
-
355
+ - [Baseline-Free Sampling in Parameter Exploring Policy Gradients: Super Symmetric PGPE, ANN 2015](https://link.springer.com/chapter/10.1007/978-3-319-09903-3_13)
356
+
357
+ The model relaxations in JaxPlan are based on the following works:
358
  + - [Poisson Variational Autoencoder, NeurIPS 2024](https://proceedings.neurips.cc/paper_files/paper/2024/file/4f3cb9576dc99d62b80726690453716f-Paper-Conference.pdf)
359
+ - [Analyzing Differentiable Fuzzy Logic Operators, AI 2022](https://www.sciencedirect.com/science/article/pii/S0004370221001533)
360
+ - [Learning with algorithmic supervision via continuous relaxations, NeurIPS 2021](https://proceedings.neurips.cc/paper_files/paper/2021/file/89ae0fe22c47d374bc9350ef99e01685-Paper.pdf)
361
+ - [Universally quantized neural compression, NeurIPS 2020](https://papers.nips.cc/paper_files/paper/2020/file/92049debbe566ca5782a3045cf300a3c-Paper.pdf)
362
+ - [Generalized Gumbel-Softmax Gradient Estimator for Generic Discrete Random Variables, 2020](https://arxiv.org/pdf/2003.01847)
363
+ - [Categorical Reparametrization with Gumbel-Softmax, ICLR 2017](https://openreview.net/pdf?id=rkE3y85ee)
@@ -1,12 +1,12 @@
1
- pyRDDLGym_jax/__init__.py,sha256=TiPG4w8nN4AzPkhugwVvZkHmAgP955NltD4QRmBLhRU,19
1
+ pyRDDLGym_jax/__init__.py,sha256=lqo7WXKfZGHPIOxgE6EWI5fGZHP2h6XrwVNNVQAUN3Q,19
2
2
  pyRDDLGym_jax/entry_point.py,sha256=dxDlO_5gneEEViwkLCg30Z-KVzUgdRXaKuFjoZklkA0,974
3
3
  pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- pyRDDLGym_jax/core/compiler.py,sha256=Rn-aIqfgfWqu45bvCfPb9tB8RIOBVdbj-pI-V3WS2Z8,89212
5
- pyRDDLGym_jax/core/logic.py,sha256=_A6eGYtLVU3pbLAezxJVB9bnClJoaFIa2mBIDdFrqoU,39655
6
- pyRDDLGym_jax/core/planner.py,sha256=4j56l7SL7F89g2QA4nOpyhODmY0DamvxYLfCMKxJNbQ,118593
4
+ pyRDDLGym_jax/core/compiler.py,sha256=_ERueJW7GQ7S8-IezreeuLs3fNCZbQZ8j7VMUVlEt1k,82306
5
+ pyRDDLGym_jax/core/logic.py,sha256=ZeCwCLqC6BvXpRT06TvE2bfPNO6ALuMzPmUvXNzW6Uw,52278
6
+ pyRDDLGym_jax/core/planner.py,sha256=0rluBXKGNHRPEPfegOWcx9__cJHr8KjZdDJtG7i1JjI,122793
7
7
  pyRDDLGym_jax/core/simulator.py,sha256=DnPL93WVCMZqtqMUoiJdfWcH9pEvNgGfDfO4NV0wIS0,9271
8
8
  pyRDDLGym_jax/core/tuning.py,sha256=RKKtDZp7unvfbhZEoaunZtcAn5xtzGYqXBB_Ij_Aapc,24205
9
- pyRDDLGym_jax/core/visualization.py,sha256=XtQL1A5dQIlfeUpte-r3lNVw-GNLxj2EYUNMz7AFOtc,70359
9
+ pyRDDLGym_jax/core/visualization.py,sha256=4BghMp8N7qtF0tdyDSqtxAxNfP9HPrQWTiXzAMJmx7o,70365
10
10
  pyRDDLGym_jax/core/assets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  pyRDDLGym_jax/core/assets/favicon.ico,sha256=RMMrI9YvmF81TgYG7FO7UAre6WmYFkV3B2GmbA1l0kM,175085
12
12
  pyRDDLGym_jax/examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -14,7 +14,7 @@ pyRDDLGym_jax/examples/run_gradient.py,sha256=KhXvijRDZ4V7N8NOI2WV8ePGpPna5_vnET
14
14
  pyRDDLGym_jax/examples/run_gym.py,sha256=rXvNWkxe4jHllvbvU_EOMji_2-2k5d4tbBKhpMm_Gaw,1526
15
15
  pyRDDLGym_jax/examples/run_plan.py,sha256=v2AvwgIa4Ejr626vBOgWFJIQvay3IPKWno02ztIFCYc,2768
16
16
  pyRDDLGym_jax/examples/run_scipy.py,sha256=wvcpWCvdjvYHntO95a7JYfY2fuCMUTKnqjJikW0PnL4,2291
17
- pyRDDLGym_jax/examples/run_tune.py,sha256=zqrhvLR5PeWJv0NsRxDCzAPmvgPgz_1NrtM1xBy6ndU,3606
17
+ pyRDDLGym_jax/examples/run_tune.py,sha256=WbGO8RudIK-cPMAMKvI8NbFQAqkG-Blbnta3Efsep6c,3828
18
18
  pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg,sha256=mE8MqhOlkHeXIGEVrnR3QY6I-_iy4uxFYRA71P1bmtk,347
19
19
  pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg,sha256=nFFYHCKQUMn8x-OpJwu2pwe1tycNSJ8iAIwSkCBn33E,370
20
20
  pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg,sha256=eJ3HvHjODoKdtX7u-AM51xQaHJnYgzEy2t3omNG2oCs,340
@@ -41,9 +41,9 @@ pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=mJo0woDevhQCSQfJg30ULVy9qG
41
41
  pyRDDLGym_jax/examples/configs/tuning_drp.cfg,sha256=CQMpSCKTkGioO7U82mHMsYWFRsutULx0V6Wrl3YzV2U,504
42
42
  pyRDDLGym_jax/examples/configs/tuning_replan.cfg,sha256=m_0nozFg_GVld0tGv92Xao_KONFJDq_vtiJKt5isqI8,501
43
43
  pyRDDLGym_jax/examples/configs/tuning_slp.cfg,sha256=KHu8II6CA-h_HblwvWHylNRjSvvGS3VHxN7JQNR4p_Q,464
44
- pyRDDLGym_jax-2.0.dist-info/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
45
- pyRDDLGym_jax-2.0.dist-info/METADATA,sha256=ZYIe9c_Tar4WO8qQOvcUIJVMmZznPUBRaegS0DH2un8,15090
46
- pyRDDLGym_jax-2.0.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
47
- pyRDDLGym_jax-2.0.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
48
- pyRDDLGym_jax-2.0.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
49
- pyRDDLGym_jax-2.0.dist-info/RECORD,,
44
+ pyrddlgym_jax-2.2.dist-info/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
45
+ pyrddlgym_jax-2.2.dist-info/METADATA,sha256=aFNUX6uUZZHS7lPbYBTmMSH6TBiWmXbEgQNxPZNWiRI,17021
46
+ pyrddlgym_jax-2.2.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
47
+ pyrddlgym_jax-2.2.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
48
+ pyrddlgym_jax-2.2.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
49
+ pyrddlgym_jax-2.2.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.0)
2
+ Generator: setuptools (75.8.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5