pyRDDLGym-jax 2.4__tar.gz → 2.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/PKG-INFO +13 -18
  2. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/README.md +10 -16
  3. pyrddlgym_jax-2.5/pyRDDLGym_jax/__init__.py +1 -0
  4. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/compiler.py +8 -4
  5. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/planner.py +144 -78
  6. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/simulator.py +37 -13
  7. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/tuning.py +25 -10
  8. pyrddlgym_jax-2.5/pyRDDLGym_jax/entry_point.py +59 -0
  9. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/tuning_drp.cfg +1 -0
  10. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/tuning_replan.cfg +1 -0
  11. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/tuning_slp.cfg +1 -0
  12. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/run_plan.py +1 -1
  13. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/run_tune.py +8 -2
  14. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax.egg-info/PKG-INFO +13 -18
  15. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/setup.py +1 -1
  16. pyrddlgym_jax-2.4/pyRDDLGym_jax/__init__.py +0 -1
  17. pyrddlgym_jax-2.4/pyRDDLGym_jax/entry_point.py +0 -27
  18. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/LICENSE +0 -0
  19. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/__init__.py +0 -0
  20. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/assets/__init__.py +0 -0
  21. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/assets/favicon.ico +0 -0
  22. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/logic.py +0 -0
  23. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/visualization.py +0 -0
  24. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/__init__.py +0 -0
  25. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +0 -0
  26. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +0 -0
  27. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +0 -0
  28. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +0 -0
  29. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +0 -0
  30. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +0 -0
  31. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +0 -0
  32. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +0 -0
  33. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +0 -0
  34. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +0 -0
  35. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +0 -0
  36. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +0 -0
  37. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +0 -0
  38. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +0 -0
  39. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +0 -0
  40. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +0 -0
  41. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +0 -0
  42. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +0 -0
  43. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +0 -0
  44. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
  45. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/default_drp.cfg +0 -0
  46. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/default_replan.cfg +0 -0
  47. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/default_slp.cfg +0 -0
  48. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/run_gradient.py +0 -0
  49. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/run_gym.py +0 -0
  50. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/run_scipy.py +0 -0
  51. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax.egg-info/SOURCES.txt +0 -0
  52. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax.egg-info/dependency_links.txt +0 -0
  53. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax.egg-info/entry_points.txt +0 -0
  54. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax.egg-info/requires.txt +0 -0
  55. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax.egg-info/top_level.txt +0 -0
  56. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: pyRDDLGym-jax
3
- Version: 2.4
3
+ Version: 2.5
4
4
  Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
5
5
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
6
6
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -39,6 +39,7 @@ Dynamic: description
39
39
  Dynamic: description-content-type
40
40
  Dynamic: home-page
41
41
  Dynamic: license
42
+ Dynamic: license-file
42
43
  Dynamic: provides-extra
43
44
  Dynamic: requires-dist
44
45
  Dynamic: requires-python
@@ -116,7 +117,7 @@ pip install pyRDDLGym-jax[extra,dashboard]
116
117
  A basic run script is provided to train JaxPlan on any RDDL problem:
117
118
 
118
119
  ```shell
119
- jaxplan plan <domain> <instance> <method> <episodes>
120
+ jaxplan plan <domain> <instance> <method> --episodes <episodes>
120
121
  ```
121
122
 
122
123
  where:
@@ -241,7 +242,7 @@ More documentation about this and other new features will be coming soon.
241
242
  A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:
242
243
 
243
244
  ```shell
244
- jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
245
+ jaxplan tune <domain> <instance> <method> --trials <trials> --iters <iters> --workers <workers> --dashboard <dashboard> --filepath <filepath>
245
246
  ```
246
247
 
247
248
  where:
@@ -251,7 +252,8 @@ where:
251
252
  - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
252
253
  - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
253
254
  - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``
254
- - ``dashboard`` is whether the optimizations are tracked in the dashboard application.
255
+ - ``dashboard`` is whether the optimizations are tracked in the dashboard application
256
+ - ``filepath`` is the optional file path where a config file with the best hyper-parameter setting will be saved.
255
257
 
256
258
  It is easy to tune a custom range of the planner's hyper-parameters efficiently.
257
259
  First create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
@@ -291,23 +293,16 @@ env = pyRDDLGym.make(domain, instance, vectorized=True)
291
293
  with open('path/to/config.cfg', 'r') as file:
292
294
  config_template = file.read()
293
295
 
294
- # map parameters in the config that will be tuned
296
+ # tune weight from 10^-1 ... 10^5 and lr from 10^-5 ... 10^1
295
297
  def power_10(x):
296
- return 10.0 ** x
297
-
298
- hyperparams = [
299
- Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10), # tune weight from 10^-1 ... 10^5
300
- Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10), # tune lr from 10^-5 ... 10^1
301
- ]
298
+ return 10.0 ** x
299
+ hyperparams = [Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),
300
+ Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10)]
302
301
 
303
302
  # build the tuner and tune
304
303
  tuning = JaxParameterTuning(env=env,
305
- config_template=config_template,
306
- hyperparams=hyperparams,
307
- online=False,
308
- eval_trials=trials,
309
- num_workers=workers,
310
- gp_iters=iters)
304
+ config_template=config_template, hyperparams=hyperparams,
305
+ online=False, eval_trials=trials, num_workers=workers, gp_iters=iters)
311
306
  tuning.tune(key=42, log_file='path/to/log.csv')
312
307
  ```
313
308
 
@@ -70,7 +70,7 @@ pip install pyRDDLGym-jax[extra,dashboard]
70
70
  A basic run script is provided to train JaxPlan on any RDDL problem:
71
71
 
72
72
  ```shell
73
- jaxplan plan <domain> <instance> <method> <episodes>
73
+ jaxplan plan <domain> <instance> <method> --episodes <episodes>
74
74
  ```
75
75
 
76
76
  where:
@@ -195,7 +195,7 @@ More documentation about this and other new features will be coming soon.
195
195
  A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:
196
196
 
197
197
  ```shell
198
- jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
198
+ jaxplan tune <domain> <instance> <method> --trials <trials> --iters <iters> --workers <workers> --dashboard <dashboard> --filepath <filepath>
199
199
  ```
200
200
 
201
201
  where:
@@ -205,7 +205,8 @@ where:
205
205
  - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
206
206
  - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
207
207
  - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``
208
- - ``dashboard`` is whether the optimizations are tracked in the dashboard application.
208
+ - ``dashboard`` is whether the optimizations are tracked in the dashboard application
209
+ - ``filepath`` is the optional file path where a config file with the best hyper-parameter setting will be saved.
209
210
 
210
211
  It is easy to tune a custom range of the planner's hyper-parameters efficiently.
211
212
  First create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
@@ -245,23 +246,16 @@ env = pyRDDLGym.make(domain, instance, vectorized=True)
245
246
  with open('path/to/config.cfg', 'r') as file:
246
247
  config_template = file.read()
247
248
 
248
- # map parameters in the config that will be tuned
249
+ # tune weight from 10^-1 ... 10^5 and lr from 10^-5 ... 10^1
249
250
  def power_10(x):
250
- return 10.0 ** x
251
-
252
- hyperparams = [
253
- Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10), # tune weight from 10^-1 ... 10^5
254
- Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10), # tune lr from 10^-5 ... 10^1
255
- ]
251
+ return 10.0 ** x
252
+ hyperparams = [Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),
253
+ Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10)]
256
254
 
257
255
  # build the tuner and tune
258
256
  tuning = JaxParameterTuning(env=env,
259
- config_template=config_template,
260
- hyperparams=hyperparams,
261
- online=False,
262
- eval_trials=trials,
263
- num_workers=workers,
264
- gp_iters=iters)
257
+ config_template=config_template, hyperparams=hyperparams,
258
+ online=False, eval_trials=trials, num_workers=workers, gp_iters=iters)
265
259
  tuning.tune(key=42, log_file='path/to/log.csv')
266
260
  ```
267
261
 
@@ -0,0 +1 @@
1
+ __version__ = '2.5'
@@ -430,7 +430,7 @@ class JaxRDDLCompiler:
430
430
  _jax_wrapped_single_step_policy,
431
431
  in_axes=(0, None, None, None, 0, None)
432
432
  )(keys, policy_params, hyperparams, step, subs, model_params)
433
- model_params = jax.tree_map(partial(jnp.mean, axis=0), model_params)
433
+ model_params = jax.tree_util.tree_map(partial(jnp.mean, axis=0), model_params)
434
434
  carry = (key, policy_params, hyperparams, subs, model_params)
435
435
  return carry, log
436
436
 
@@ -440,7 +440,7 @@ class JaxRDDLCompiler:
440
440
  start = (key, policy_params, hyperparams, subs, model_params)
441
441
  steps = jnp.arange(n_steps)
442
442
  end, log = jax.lax.scan(_jax_wrapped_batched_step_policy, start, steps)
443
- log = jax.tree_map(partial(jnp.swapaxes, axis1=0, axis2=1), log)
443
+ log = jax.tree_util.tree_map(partial(jnp.swapaxes, axis1=0, axis2=1), log)
444
444
  model_params = end[-1]
445
445
  return log, model_params
446
446
 
@@ -707,7 +707,10 @@ class JaxRDDLCompiler:
707
707
  sample = jnp.asarray(value, dtype=self._fix_dtype(value))
708
708
  new_slices = [None] * len(jax_nested_expr)
709
709
  for (i, jax_expr) in enumerate(jax_nested_expr):
710
- new_slices[i], key, err, params = jax_expr(x, params, key)
710
+ new_slice, key, err, params = jax_expr(x, params, key)
711
+ if not jnp.issubdtype(jnp.result_type(new_slice), jnp.integer):
712
+ new_slice = jnp.asarray(new_slice, dtype=self.INT)
713
+ new_slices[i] = new_slice
711
714
  error |= err
712
715
  new_slices = tuple(new_slices)
713
716
  sample = sample[new_slices]
@@ -986,7 +989,8 @@ class JaxRDDLCompiler:
986
989
  sample_cases = [None] * len(jax_cases)
987
990
  for (i, jax_case) in enumerate(jax_cases):
988
991
  sample_cases[i], key, err_case, params = jax_case(x, params, key)
989
- err |= err_case
992
+ err |= err_case
993
+ sample_cases = jnp.asarray(sample_cases)
990
994
  sample_cases = jnp.asarray(sample_cases, dtype=self._fix_dtype(sample_cases))
991
995
 
992
996
  # predicate (enum) is an integer - use it to extract from case array