pyRDDLGym-jax 0.5.tar.gz → 1.1.tar.gz
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/PKG-INFO +159 -103
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/README.md +155 -102
- pyrddlgym_jax-1.1/pyRDDLGym_jax/__init__.py +1 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/core/compiler.py +463 -592
- pyrddlgym_jax-1.1/pyRDDLGym_jax/core/logic.py +1083 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/core/planner.py +336 -472
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/core/simulator.py +7 -5
- pyrddlgym_jax-1.1/pyRDDLGym_jax/core/tuning.py +525 -0
- pyrddlgym_jax-1.1/pyRDDLGym_jax/core/visualization.py +1463 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +4 -5
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/default_replan.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
- pyrddlgym_jax-1.1/pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
- pyrddlgym_jax-1.1/pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
- pyrddlgym_jax-1.1/pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/run_plan.py +4 -1
- pyrddlgym_jax-1.1/pyRDDLGym_jax/examples/run_tune.py +91 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax.egg-info/PKG-INFO +159 -103
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax.egg-info/SOURCES.txt +5 -4
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax.egg-info/requires.txt +4 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/setup.py +3 -2
- pyrddlgym_jax-0.5/pyRDDLGym_jax/__init__.py +0 -1
- pyrddlgym_jax-0.5/pyRDDLGym_jax/core/logic.py +0 -843
- pyrddlgym_jax-0.5/pyRDDLGym_jax/core/tuning.py +0 -700
- pyrddlgym_jax-0.5/pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
- pyrddlgym_jax-0.5/pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
- pyrddlgym_jax-0.5/pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
- pyrddlgym_jax-0.5/pyRDDLGym_jax/examples/run_tune.py +0 -78
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/LICENSE +0 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/core/__init__.py +0 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/__init__.py +0 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/run_gradient.py +0 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/run_gym.py +0 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/run_scipy.py +0 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax.egg-info/dependency_links.txt +0 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax.egg-info/top_level.txt +0 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: pyRDDLGym-jax
-Version: 0.5
+Version: 1.1
 Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
 Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
 Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -25,51 +25,77 @@ Requires-Dist: tensorflow-probability>=0.21.0
 Provides-Extra: extra
 Requires-Dist: bayesian-optimization>=2.0.0; extra == "extra"
 Requires-Dist: rddlrepository>=2.0; extra == "extra"
+Provides-Extra: dashboard
+Requires-Dist: dash>=2.18.0; extra == "dashboard"
+Requires-Dist: dash-bootstrap-components>=1.6.0; extra == "dashboard"

 # pyRDDLGym-jax

-
+
+[](https://pypi.org/project/pyRDDLGym-jax/)
+[](https://pyrddlgym.readthedocs.io/en/latest/jax.html)
+
+[](https://pypistats.org/packages/pyrddlgym-jax)

-
-1. automated translation and compilation of RDDL description files into [JAX](https://github.com/google/jax), converting any RDDL domain to a differentiable simulator!
-2. powerful, fast and scalable gradient-based planning algorithms, with extendible and flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and much more!
+[Installation](#installation) | [Run cmd](#running-from-the-command-line) | [Run python](#running-from-another-python-application) | [Configuration](#configuring-the-planner) | [Dashboard](#jaxplan-dashboard) | [Tuning](#tuning-the-planner) | [Simulation](#simulation) | [Citing](#citing-jaxplan)

-
-
-
+**pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.**
+
+Purpose:
+
+1. automatic translation of any RDDL description file into a differentiable simulator in JAX
+2. flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and Bayesian hyper-parameter tuning.
+
+Some demos of solved problems by JaxPlan:
+
+<p align="middle">
+<img src="Images/intruders.gif" width="120" height="120" margin=0/>
+<img src="Images/marsrover.gif" width="120" height="120" margin=0/>
+<img src="Images/pong.gif" width="120" height="120" margin=0/>
+<img src="Images/quadcopter.gif" width="120" height="120" margin=0/>
+<img src="Images/reacher.gif" width="120" height="120" margin=0/>
+<img src="Images/reservoir.gif" width="120" height="120" margin=0/>
+</p>

-
+> [!WARNING]
+> Starting in version 1.0 (major release), the ``weight`` parameter in the config file was removed,
+> and was moved to the individual logic components which have their own unique weight parameter assigned.
+> Furthermore, the tuning module has been redesigned from the ground up, and supports tuning arbitrary hyper-parameters via config templates!
+> Finally, the terrible visualizer for the planner was removed and replaced with an interactive real-time dashboard (similar to tensorboard, but custom designed for the planner)!

-
-
-- [
-- [Configuring the Planner](#configuring-the-planner)
-- [Simulation](#simulation)
-- [Manual Gradient Calculation](#manual-gradient-calculation)
-- [Citing pyRDDLGym-jax](#citing-pyrddlgym-jax)
+> [!NOTE]
+> While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
+> If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).

 ## Installation

-To
-- ``pyRDDLGym>=2.0``
-- ``tqdm>=4.66``
-- ``jax>=0.4.12``
-- ``optax>=0.1.9``
-- ``dm-haiku>=0.0.10``
-- ``tensorflow-probability>=0.21.0``
+To install the bare-bones version of JaxPlan with **minimum installation requirements**:

-
-
+```shell
+pip install pyRDDLGym-jax
+```

-
+To install JaxPlan with the **automatic hyper-parameter tuning** and rddlrepository:

 ```shell
 pip install pyRDDLGym-jax[extra]
 ```

+(Since version 1.0) To install JaxPlan with the **visualization dashboard**:
+
+```shell
+pip install pyRDDLGym-jax[dashboard]
+```
+
+(Since version 1.0) To install JaxPlan with **all options**:
+
+```shell
+pip install pyRDDLGym-jax[extra,dashboard]
+```
+
 ## Running from the Command Line

-A basic run script is provided to run
+A basic run script is provided to run JaxPlan on any domain in ``rddlrepository`` from the install directory of pyRDDLGym-jax:

 ```shell
 python -m pyRDDLGym_jax.examples.run_plan <domain> <instance> <method> <episodes>
@@ -86,35 +112,15 @@ The ``method`` parameter supports three possible modes:
 - ``drp`` is the deep reactive policy network described [in this paper](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
 - ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step.

-
-
-```shell
-python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials> <iters> <workers>
-```
-
-where:
-- ``domain`` is the domain identifier as specified in rddlrepository
-- ``instance`` is the instance identifier
-- ``method`` is the planning method to use (i.e. drp, slp, replan)
-- ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
-- ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
-- ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``.
-
-For example, the following will train the Jax Planner on the Quadcopter domain with 4 drones:
+For example, the following will train JaxPlan on the Quadcopter domain with 4 drones:

 ```shell
 python -m pyRDDLGym_jax.examples.run_plan Quadcopter 1 slp
 ```

-
+## Running from Another Python Application

-
-<img src="Images/quadcopter.gif" width="400" height="400" margin=1/>
-</p>
-
-## Running from within Python
-
-To run the Jax planner from within a Python application, refer to the following example:
+To run JaxPlan from within a Python application, refer to the following example:

 ```python
 import pyRDDLGym
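The hunk above ends just after the first line of the new Python example. For orientation only, here is a minimal sketch of the kind of end-to-end script that section describes, pieced together from API names that appear elsewhere in this diff (`JaxBackpropPlanner`, `JaxOfflineController`); the domain and instance names are placeholders and the exact call signatures should be checked against the package documentation:

```python
import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOfflineController

# the JAX planner works on the vectorized environment representation
env = pyRDDLGym.make("Quadcopter", "1", vectorized=True)

# build a gradient-based planner and wrap it in an offline controller
# (the plan is optimized once up front, then rolled out in the environment)
planner = JaxBackpropPlanner(rddl=env.model)
controller = JaxOfflineController(planner)
controller.evaluate(env, verbose=True, render=True)

env.close()
```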
@@ -144,17 +150,15 @@ The basic structure of a configuration file is provided below for a straight-lin
 ```ini
 [Model]
 logic='FuzzyLogic'
-
-
-
+comparison_kwargs={'weight': 20}
+rounding_kwargs={'weight': 20}
+control_kwargs={'weight': 20}

 [Optimizer]
 method='JaxStraightLinePlan'
 method_kwargs={}
 optimizer='rmsprop'
 optimizer_kwargs={'learning_rate': 0.001}
-batch_size_train=1
-batch_size_test=1

 [Training]
 key=42
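The `comparison_kwargs`, `rounding_kwargs` and `control_kwargs` entries added above each carry a `weight` that controls how sharply the fuzzy-logic relaxation approximates the corresponding exact operation. As a conceptual illustration only (not the package's actual implementation), a sigmoid relaxation of a comparison behaves like this:

```python
import jax.numpy as jnp

def soft_greater(a, b, weight=20.0):
    # Smooth stand-in for the Boolean comparison (a > b): small weights give a
    # heavily smoothed comparison with informative gradients, while large
    # weights approach the exact 0/1 step function.
    return 1.0 / (1.0 + jnp.exp(-weight * (a - b)))

print(soft_greater(1.0, 0.8))  # ~0.98, close to "True"
print(soft_greater(0.8, 1.0))  # ~0.02, close to "False"
```

Tuning `weight` therefore trades off fidelity to the original discrete model against the quality of the gradients available to the planner.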
@@ -179,7 +183,7 @@ method_kwargs={'topology': [128, 64], 'activation': 'tanh'}
 ```

 The configuration file must then be passed to the planner during initialization.
-For example, the [previous script here](#running-from-within-python) can be modified to set parameters from a config file:
+For example, the [previous script here](#running-from-another-python-application) can be modified to set parameters from a config file:

 ```python
 from pyRDDLGym_jax.core.planner import load_config
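This hunk stops at the `load_config` import, while the next hunk's context line (`controller = JaxOfflineController(planner, **train_args)`) shows where the loaded arguments end up. A minimal sketch of the glue between the two, assuming `load_config` returns the planner, plan, and training keyword-argument groups (the three-way unpacking and the placeholder paths are assumptions):

```python
import pyRDDLGym
from pyRDDLGym_jax.core.planner import (
    JaxBackpropPlanner, JaxOfflineController, load_config)

# parse the [Model], [Optimizer] and [Training] sections of the config file
planner_args, _, train_args = load_config('path/to/config.cfg')

env = pyRDDLGym.make("Quadcopter", "1", vectorized=True)

# forward the parsed settings to the planner and the controller
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
controller = JaxOfflineController(planner, **train_args)
controller.evaluate(env, verbose=True, render=True)
```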
@@ -193,6 +197,101 @@ controller = JaxOfflineController(planner, **train_args)
 ...
 ```

+### JaxPlan Dashboard
+
+Since version 1.0, JaxPlan has an optional dashboard that allows keeping track of planner performance across multiple runs,
+visualizing the policy or model, and other useful debugging features.
+
+<p align="middle">
+<img src="Images/dashboard.png" width="480" height="248" margin=0/>
+</p>
+
+To run the dashboard, add the following entry to your config file:
+
+```ini
+...
+[Training]
+dashboard=True
+...
+```
+
+More documentation about this and other new features will be coming soon.
+
+### Tuning the Planner
+
+It is easy to tune the planner's hyper-parameters efficiently and automatically using Bayesian optimization.
+To do this, first create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
+
+```ini
+[Model]
+logic='FuzzyLogic'
+comparison_kwargs={'weight': TUNABLE_WEIGHT}
+rounding_kwargs={'weight': TUNABLE_WEIGHT}
+control_kwargs={'weight': TUNABLE_WEIGHT}
+
+[Optimizer]
+method='JaxStraightLinePlan'
+method_kwargs={}
+optimizer='rmsprop'
+optimizer_kwargs={'learning_rate': TUNABLE_LEARNING_RATE}
+
+[Training]
+train_seconds=30
+print_summary=False
+print_progress=False
+train_on_reset=True
+```
+
+This would allow tuning the sharpness of the model relaxations and the learning rate of the optimizer.
+
+Next, you must link the patterns in the config with concrete hyper-parameter ranges the tuner will understand:
+
+```python
+import pyRDDLGym
+from pyRDDLGym_jax.core.tuning import JaxParameterTuning, Hyperparameter
+
+# set up the environment
+env = pyRDDLGym.make(domain, instance, vectorized=True)
+
+# load the config file template with planner settings
+with open('path/to/config.cfg', 'r') as file:
+    config_template = file.read()
+
+# map parameters in the config that will be tuned
+def power_10(x):
+    return 10.0 ** x
+
+hyperparams = [
+    Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),         # tune weight from 10^-1 ... 10^5
+    Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10),  # tune lr from 10^-5 ... 10^1
+]
+
+# build the tuner and tune
+tuning = JaxParameterTuning(env=env,
+                            config_template=config_template,
+                            hyperparams=hyperparams,
+                            online=False,
+                            eval_trials=trials,
+                            num_workers=workers,
+                            gp_iters=iters)
+tuning.tune(key=42, log_file='path/to/log.csv')
+```
+
+A basic run script is provided to run the automatic hyper-parameter tuning for the most sensitive parameters of JaxPlan:
+
+```shell
+python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials> <iters> <workers>
+```
+
+where:
+- ``domain`` is the domain identifier as specified in rddlrepository
+- ``instance`` is the instance identifier
+- ``method`` is the planning method to use (i.e. drp, slp, replan)
+- ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
+- ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
+- ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``.
+
+
 ## Simulation

 The JAX compiler can be used as a backend for simulating and evaluating RDDL environments:
@@ -206,47 +305,17 @@ from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator
 env = pyRDDLGym.make("domain", "instance", backend=JaxRDDLSimulator)

 # evaluate the random policy
-agent = RandomAgent(action_space=env.action_space,
-                    num_actions=env.max_allowed_actions)
+agent = RandomAgent(action_space=env.action_space, num_actions=env.max_allowed_actions)
 agent.evaluate(env, verbose=True, render=True)
 ```

 For some domains, the JAX backend could perform better than the numpy-based one, due to various compiler optimizations.
 In any event, the simulation results using the JAX backend should (almost) always match the numpy backend.

-## Manual Gradient Calculation
-
-For custom applications, it is desirable to compute gradients of the model that can be optimized downstream.
-Fortunately, we provide a very convenient function for compiling the transition/step function ``P(s, a, s')`` of the environment into JAX.
-
-
-import pyRDDLGym
-from pyRDDLGym_jax.core.planner import JaxRDDLCompilerWithGrad
-
-env = pyRDDLGym.make("domain", "instance", vectorized=True)
-
-# create the step function
-compiled = JaxRDDLCompilerWithGrad(rddl=env.model)
-compiled.compile()
-step_fn = compiled.compile_transition()
-```
-
-This will return a JAX compiled (pure) function requiring the following inputs:
-- ``key`` is the ``jax.random.PRNGKey`` key for reproducible randomness
-- ``actions`` is the dictionary of action fluent tensors
-- ``subs`` is the dictionary of state-fluent and non-fluent tensors
-- ``model_params`` are the parameters of the differentiable relaxations, such as ``weight``
-
-The function returns a dictionary containing a variety of variables, such as updated pvariables including next-state fluents (``pvar``), reward obtained (``reward``), error codes (``error``).
-It is thus possible to apply any JAX transformation to the output of the function, such as computing gradient using ``jax.grad()`` or batched simulation using ``jax.vmap()``.
-
-Compilation of entire rollouts is also possible by calling the ``compile_rollouts`` function.
-An [example is provided to illustrate how you can define your own policy class and compute the return gradient manually](https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/pyRDDLGym_jax/examples/run_gradient.py).
-
-## Citing pyRDDLGym-jax
-
-The [following citation](https://ojs.aaai.org/index.php/ICAPS/article/view/31480) describes the main ideas of the framework. Please cite it if you found it useful:
+## Citing JaxPlan
+
+The [following citation](https://ojs.aaai.org/index.php/ICAPS/article/view/31480) describes the main ideas of JaxPlan. Please cite it if you found it useful:

 ```
 @inproceedings{gimelfarb2024jaxplan,
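For reference, the simulation example near the start of the hunk above becomes self-contained once the imports from the hunk header are restored; a sketch (the `RandomAgent` import path follows pyRDDLGym's usual layout and is an assumption here):

```python
import pyRDDLGym
from pyRDDLGym.core.policy import RandomAgent
from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator

# swap the default numpy-based simulator for the JAX backend
env = pyRDDLGym.make("domain", "instance", backend=JaxRDDLSimulator)

# evaluate the random policy against the JAX backend
agent = RandomAgent(action_space=env.action_space, num_actions=env.max_allowed_actions)
agent.evaluate(env, verbose=True, render=True)
env.close()
```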
@@ -258,21 +327,8 @@ The [following citation](https://ojs.aaai.org/index.php/ICAPS/article/view/31480
 }
 ```

-The utility optimization is discussed in [this paper](https://ojs.aaai.org/index.php/AAAI/article/view/21226):
-
-```
-@inproceedings{patton2022distributional,
-  title={A distributional framework for risk-sensitive end-to-end planning in continuous mdps},
-  author={Patton, Noah and Jeong, Jihwan and Gimelfarb, Mike and Sanner, Scott},
-  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
-  volume={36},
-  number={9},
-  pages={9894--9901},
-  year={2022}
-}
-```
-
 Some of the implementation details derive from the following literature, which you may wish to also cite in your research papers:
+- [A Distributional Framework for Risk-Sensitive End-to-End Planning in Continuous MDPs](https://ojs.aaai.org/index.php/AAAI/article/view/21226)
 - [Deep reactive policies for planning in stochastic nonlinear domains, AAAI 2019](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
 - [Scalable planning with tensorflow for hybrid nonlinear domains, NeurIPS 2017](https://proceedings.neurips.cc/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)