pyRDDLGym-jax 2.4.tar.gz → 2.5.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/PKG-INFO +13 -18
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/README.md +10 -16
- pyrddlgym_jax-2.5/pyRDDLGym_jax/__init__.py +1 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/compiler.py +8 -4
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/planner.py +144 -78
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/simulator.py +37 -13
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/tuning.py +25 -10
- pyrddlgym_jax-2.5/pyRDDLGym_jax/entry_point.py +59 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/tuning_drp.cfg +1 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/tuning_replan.cfg +1 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/tuning_slp.cfg +1 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/run_plan.py +1 -1
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/run_tune.py +8 -2
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax.egg-info/PKG-INFO +13 -18
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/setup.py +1 -1
- pyrddlgym_jax-2.4/pyRDDLGym_jax/__init__.py +0 -1
- pyrddlgym_jax-2.4/pyRDDLGym_jax/entry_point.py +0 -27
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/LICENSE +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/__init__.py +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/assets/__init__.py +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/assets/favicon.ico +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/logic.py +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/visualization.py +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/__init__.py +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/default_drp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/default_replan.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/default_slp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/run_gradient.py +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/run_gym.py +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/run_scipy.py +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax.egg-info/SOURCES.txt +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax.egg-info/dependency_links.txt +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax.egg-info/entry_points.txt +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax.egg-info/requires.txt +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax.egg-info/top_level.txt +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/setup.cfg +0 -0
{pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/PKG-INFO

@@ -1,6 +1,6 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: pyRDDLGym-jax
-Version: 2.4
+Version: 2.5
 Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
 Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
 Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -39,6 +39,7 @@ Dynamic: description
 Dynamic: description-content-type
 Dynamic: home-page
 Dynamic: license
+Dynamic: license-file
 Dynamic: provides-extra
 Dynamic: requires-dist
 Dynamic: requires-python
@@ -116,7 +117,7 @@ pip install pyRDDLGym-jax[extra,dashboard]
 A basic run script is provided to train JaxPlan on any RDDL problem:
 
 ```shell
-jaxplan plan <domain> <instance> <method> <episodes>
+jaxplan plan <domain> <instance> <method> --episodes <episodes>
 ```
 
 where:
@@ -241,7 +242,7 @@ More documentation about this and other new features will be coming soon.
 A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:
 
 ```shell
-jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
+jaxplan tune <domain> <instance> <method> --trials <trials> --iters <iters> --workers <workers> --dashboard <dashboard> --filepath <filepath>
 ```
 
 where:
@@ -251,7 +252,8 @@ where:
 - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
 - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
 - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``
-- ``dashboard`` is whether the optimizations are tracked in the dashboard application
+- ``dashboard`` is whether the optimizations are tracked in the dashboard application
+- ``filepath`` is the optional file path where a config file with the best hyper-parameter setting will be saved.
 
 It is easy to tune a custom range of the planner's hyper-parameters efficiently.
 First create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
@@ -291,23 +293,16 @@ env = pyRDDLGym.make(domain, instance, vectorized=True)
 with open('path/to/config.cfg', 'r') as file:
     config_template = file.read()
 
-# 
+# tune weight from 10^-1 ... 10^5 and lr from 10^-5 ... 10^1
 def power_10(x):
-    return 10.0 ** x
-
-
-    Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),  # tune weight from 10^-1 ... 10^5
-    Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10),  # tune lr from 10^-5 ... 10^1
-]
+    return 10.0 ** x
+hyperparams = [Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),
+               Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10)]
 
 # build the tuner and tune
 tuning = JaxParameterTuning(env=env,
-                            config_template=config_template,
-
-                            online=False,
-                            eval_trials=trials,
-                            num_workers=workers,
-                            gp_iters=iters)
+                            config_template=config_template, hyperparams=hyperparams,
+                            online=False, eval_trials=trials, num_workers=workers, gp_iters=iters)
 tuning.tune(key=42, log_file='path/to/log.csv')
 ```
{pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/README.md

@@ -70,7 +70,7 @@ pip install pyRDDLGym-jax[extra,dashboard]
 A basic run script is provided to train JaxPlan on any RDDL problem:
 
 ```shell
-jaxplan plan <domain> <instance> <method> <episodes>
+jaxplan plan <domain> <instance> <method> --episodes <episodes>
 ```
 
 where:
@@ -195,7 +195,7 @@ More documentation about this and other new features will be coming soon.
 A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:
 
 ```shell
-jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
+jaxplan tune <domain> <instance> <method> --trials <trials> --iters <iters> --workers <workers> --dashboard <dashboard> --filepath <filepath>
 ```
 
 where:
@@ -205,7 +205,8 @@ where:
 - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
 - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
 - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``
-- ``dashboard`` is whether the optimizations are tracked in the dashboard application
+- ``dashboard`` is whether the optimizations are tracked in the dashboard application
+- ``filepath`` is the optional file path where a config file with the best hyper-parameter setting will be saved.
 
 It is easy to tune a custom range of the planner's hyper-parameters efficiently.
 First create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
@@ -245,23 +246,16 @@ env = pyRDDLGym.make(domain, instance, vectorized=True)
 with open('path/to/config.cfg', 'r') as file:
     config_template = file.read()
 
-# 
+# tune weight from 10^-1 ... 10^5 and lr from 10^-5 ... 10^1
 def power_10(x):
-    return 10.0 ** x
-
-
-    Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),  # tune weight from 10^-1 ... 10^5
-    Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10),  # tune lr from 10^-5 ... 10^1
-]
+    return 10.0 ** x
+hyperparams = [Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),
+               Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10)]
 
 # build the tuner and tune
 tuning = JaxParameterTuning(env=env,
-                            config_template=config_template,
-
-                            online=False,
-                            eval_trials=trials,
-                            num_workers=workers,
-                            gp_iters=iters)
+                            config_template=config_template, hyperparams=hyperparams,
+                            online=False, eval_trials=trials, num_workers=workers, gp_iters=iters)
 tuning.tune(key=42, log_file='path/to/log.csv')
 ```
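The README change above packs the tunable parameters into a `hyperparams` list handed to `JaxParameterTuning`, each one mapping a search-space exponent to an actual value via `power_10`. Below is a minimal standalone sketch of that mapping idea; it does not call the package, and the `search_space` dict and random proposal are hypothetical stand-ins for what the Bayesian optimizer does when it fills in the `TUNABLE_*` patterns of the config template.

```python
import random

def power_10(x: float) -> float:
    # map a unit-free exponent sampled by the optimizer to a parameter value
    return 10.0 ** x

# hypothetical search space mirroring the ranges in the diff:
# weight in 10^-1 ... 10^5, learning rate in 10^-5 ... 10^1
search_space = {
    'TUNABLE_WEIGHT': (-1.0, 5.0),
    'TUNABLE_LEARNING_RATE': (-5.0, 1.0),
}

# one illustrative proposal: sample exponents, then map them to the values that
# would be substituted for the TUNABLE_* patterns in the config template
proposal = {name: power_10(random.uniform(lo, hi))
            for name, (lo, hi) in search_space.items()}
print(proposal)
```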
pyrddlgym_jax-2.5/pyRDDLGym_jax/__init__.py (new file)

@@ -0,0 +1 @@
+__version__ = '2.5'
{pyrddlgym_jax-2.4 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/compiler.py

@@ -430,7 +430,7 @@ class JaxRDDLCompiler:
             _jax_wrapped_single_step_policy,
             in_axes=(0, None, None, None, 0, None)
         )(keys, policy_params, hyperparams, step, subs, model_params)
-        model_params = jax.tree_map(partial(jnp.mean, axis=0), model_params)
+        model_params = jax.tree_util.tree_map(partial(jnp.mean, axis=0), model_params)
         carry = (key, policy_params, hyperparams, subs, model_params)
         return carry, log
 
@@ -440,7 +440,7 @@ class JaxRDDLCompiler:
         start = (key, policy_params, hyperparams, subs, model_params)
         steps = jnp.arange(n_steps)
         end, log = jax.lax.scan(_jax_wrapped_batched_step_policy, start, steps)
-        log = jax.tree_map(partial(jnp.swapaxes, axis1=0, axis2=1), log)
+        log = jax.tree_util.tree_map(partial(jnp.swapaxes, axis1=0, axis2=1), log)
         model_params = end[-1]
         return log, model_params
 
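Both compiler.py hunks above replace the deprecated `jax.tree_map` alias with `jax.tree_util.tree_map`. A minimal self-contained sketch of the same pattern follows; the pytree below is a made-up stand-in for `model_params` / `log`, not code from the package.

```python
from functools import partial

import jax
import jax.numpy as jnp

# a small pytree of batched arrays, standing in for model_params / log
tree = {'a': jnp.ones((4, 3)), 'b': jnp.arange(8.0).reshape(4, 2)}

# old spelling (removed above): jax.tree_map(partial(jnp.mean, axis=0), tree)
# new spelling: the same function lives under jax.tree_util
averaged = jax.tree_util.tree_map(partial(jnp.mean, axis=0), tree)
swapped = jax.tree_util.tree_map(partial(jnp.swapaxes, axis1=0, axis2=1),
                                 {'log': jnp.zeros((5, 4, 2))})

print(jax.tree_util.tree_map(jnp.shape, averaged))  # {'a': (3,), 'b': (2,)}
print(jax.tree_util.tree_map(jnp.shape, swapped))   # {'log': (4, 5, 2)}
```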
@@ -707,7 +707,10 @@ class JaxRDDLCompiler:
         sample = jnp.asarray(value, dtype=self._fix_dtype(value))
         new_slices = [None] * len(jax_nested_expr)
         for (i, jax_expr) in enumerate(jax_nested_expr):
-
+            new_slice, key, err, params = jax_expr(x, params, key)
+            if not jnp.issubdtype(jnp.result_type(new_slice), jnp.integer):
+                new_slice = jnp.asarray(new_slice, dtype=self.INT)
+            new_slices[i] = new_slice
             error |= err
         new_slices = tuple(new_slices)
         sample = sample[new_slices]
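The four added lines evaluate each nested index expression and force it to an integer dtype before it is used to slice `sample`, since JAX only accepts integer (or boolean) indices. Here is a standalone sketch of that guard using only standard JAX calls; `INT` and `as_index` are made-up names standing in for the compiler's internals.

```python
import jax.numpy as jnp

INT = jnp.int32  # stand-in for the compiler's integer dtype

def as_index(new_slice):
    # coerce a computed index to an integer dtype so it can be used for gathering
    if not jnp.issubdtype(jnp.result_type(new_slice), jnp.integer):
        new_slice = jnp.asarray(new_slice, dtype=INT)
    return new_slice

sample = jnp.arange(12.0).reshape(3, 4)
idx = (as_index(jnp.asarray(1.0)), as_index(jnp.asarray(2.0)))  # float-valued indices
print(sample[idx])  # gathers sample[1, 2] -> 6.0
```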
@@ -986,7 +989,8 @@ class JaxRDDLCompiler:
         sample_cases = [None] * len(jax_cases)
         for (i, jax_case) in enumerate(jax_cases):
             sample_cases[i], key, err_case, params = jax_case(x, params, key)
-            err |= err_case
+            err |= err_case
+        sample_cases = jnp.asarray(sample_cases)
         sample_cases = jnp.asarray(sample_cases, dtype=self._fix_dtype(sample_cases))
 
         # predicate (enum) is an integer - use it to extract from case array
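The extra `jnp.asarray(sample_cases)` call stacks the per-case samples into one array before `_fix_dtype` inspects it; the integer enum predicate then indexes into the stacked array. A minimal sketch of that switch pattern, with illustrative case values and predicate (not from the package):

```python
import jax.numpy as jnp

# per-case samples, standing in for the compiled case expressions' outputs
cases = [jnp.asarray(0.0), jnp.asarray(1.5), jnp.asarray(-2.0)]

# stacking first yields an array whose dtype can be inspected and normalized,
# which is what the added jnp.asarray(sample_cases) line enables
sample_cases = jnp.asarray(cases)

pred = jnp.asarray(2)        # enum predicate as an integer index
print(sample_cases[pred])    # selects the case for predicate value 2 -> -2.0
```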