pyRDDLGym-jax 0.5__tar.gz → 1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/PKG-INFO +153 -96
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/README.md +149 -95
- pyrddlgym_jax-1.0/pyRDDLGym_jax/__init__.py +1 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/core/compiler.py +463 -592
- pyrddlgym_jax-1.0/pyRDDLGym_jax/core/logic.py +1083 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/core/planner.py +329 -463
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/core/simulator.py +7 -5
- pyrddlgym_jax-1.0/pyRDDLGym_jax/core/tuning.py +511 -0
- pyrddlgym_jax-1.0/pyRDDLGym_jax/core/visualization.py +1463 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +4 -5
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/default_replan.cfg +3 -3
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
- pyrddlgym_jax-1.0/pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
- pyrddlgym_jax-1.0/pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
- pyrddlgym_jax-1.0/pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/run_plan.py +4 -1
- pyrddlgym_jax-1.0/pyRDDLGym_jax/examples/run_tune.py +91 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax.egg-info/PKG-INFO +153 -96
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax.egg-info/SOURCES.txt +5 -4
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax.egg-info/requires.txt +4 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/setup.py +3 -2
- pyrddlgym_jax-0.5/pyRDDLGym_jax/__init__.py +0 -1
- pyrddlgym_jax-0.5/pyRDDLGym_jax/core/logic.py +0 -843
- pyrddlgym_jax-0.5/pyRDDLGym_jax/core/tuning.py +0 -700
- pyrddlgym_jax-0.5/pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
- pyrddlgym_jax-0.5/pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
- pyrddlgym_jax-0.5/pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
- pyrddlgym_jax-0.5/pyRDDLGym_jax/examples/run_tune.py +0 -78
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/LICENSE +0 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/core/__init__.py +0 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/__init__.py +0 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/run_gradient.py +0 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/run_gym.py +0 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/run_scipy.py +0 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax.egg-info/dependency_links.txt +0 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax.egg-info/top_level.txt +0 -0
- {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: pyRDDLGym-jax
|
|
3
|
-
Version: 0.5
|
|
3
|
+
Version: 1.0
|
|
4
4
|
Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
|
|
5
5
|
Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
|
|
6
6
|
Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
|
|
@@ -25,51 +25,78 @@ Requires-Dist: tensorflow-probability>=0.21.0
|
|
|
25
25
|
Provides-Extra: extra
|
|
26
26
|
Requires-Dist: bayesian-optimization>=2.0.0; extra == "extra"
|
|
27
27
|
Requires-Dist: rddlrepository>=2.0; extra == "extra"
|
|
28
|
+
Provides-Extra: dashboard
|
|
29
|
+
Requires-Dist: dash>=2.18.0; extra == "dashboard"
|
|
30
|
+
Requires-Dist: dash-bootstrap-components>=1.6.0; extra == "dashboard"
|
|
28
31
|
|
|
29
32
|
# pyRDDLGym-jax
|
|
30
33
|
|
|
31
|
-
|
|
34
|
+
**pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.** It provides:
|
|
32
35
|
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
+
1. automatic translation of any RDDL description file into a differentiable simulator in JAX
|
|
37
|
+
2. flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and Bayesian hyper-parameter tuning.
|
|
38
|
+
|
|
39
|
+
Some demos of problems solved by JaxPlan:
|
|
40
|
+
|
|
41
|
+
<p align="center">
|
|
42
|
+
<img src="Images/intruders.gif" width="120" height="120" margin=0/>
|
|
43
|
+
<img src="Images/marsrover.gif" width="120" height="120" margin=0/>
|
|
44
|
+
<img src="Images/pong.gif" width="120" height="120" margin=0/>
|
|
45
|
+
<img src="Images/quadcopter.gif" width="120" height="120" margin=0/>
|
|
46
|
+
<img src="Images/reacher.gif" width="120" height="120" margin=0/>
|
|
47
|
+
<img src="Images/reservoir.gif" width="120" height="120" margin=0/>
|
|
48
|
+
</p>
|
|
49
|
+
|
|
50
|
+
> [!WARNING]
|
|
51
|
+
> Starting in version 1.0 (major release), the ``weight`` parameter in the config file was removed,
|
|
52
|
+
> and was moved to the individual logic components, each of which has its own unique weight parameter.
|
|
53
|
+
> Furthermore, the tuning module has been redesigned from the ground up, and supports tuning arbitrary hyper-parameters via config templates!
|
|
54
|
+
> Finally, the old visualizer for the planner was removed and replaced with an interactive real-time dashboard (similar to TensorBoard, but custom-designed for the planner)!
|
|
36
55
|
|
|
37
56
|
> [!NOTE]
|
|
38
|
-
> While
|
|
57
|
+
> While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
|
|
39
58
|
> If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
|
|
40
59
|
|
|
41
60
|
## Contents
|
|
42
61
|
|
|
43
62
|
- [Installation](#installation)
|
|
44
63
|
- [Running from the Command Line](#running-from-the-command-line)
|
|
45
|
-
- [Running from within Python](#running-from-within-python)
|
|
64
|
+
- [Running from Another Python Application](#running-from-another-python-application)
|
|
46
65
|
- [Configuring the Planner](#configuring-the-planner)
|
|
66
|
+
- [JaxPlan Dashboard](#jaxplan-dashboard)
|
|
67
|
+
- [Tuning the Planner](#tuning-the-planner)
|
|
47
68
|
- [Simulation](#simulation)
|
|
48
|
-
- [Manual Gradient Calculation](#manual-gradient-calculation)
|
|
49
|
-
- [Citing pyRDDLGym-jax](#citing-pyrddlgym-jax)
|
|
69
|
+
- [Citing JaxPlan](#citing-jaxplan)
|
|
50
70
|
|
|
51
71
|
## Installation
|
|
52
72
|
|
|
53
|
-
To
|
|
54
|
-
- ``pyRDDLGym>=2.0``
|
|
55
|
-
- ``tqdm>=4.66``
|
|
56
|
-
- ``jax>=0.4.12``
|
|
57
|
-
- ``optax>=0.1.9``
|
|
58
|
-
- ``dm-haiku>=0.0.10``
|
|
59
|
-
- ``tensorflow-probability>=0.21.0``
|
|
73
|
+
To install the bare-bones version of JaxPlan with **minimum installation requirements**:
|
|
60
74
|
|
|
61
|
-
|
|
62
|
-
|
|
75
|
+
```shell
|
|
76
|
+
pip install pyRDDLGym-jax
|
|
77
|
+
```
|
|
63
78
|
|
|
64
|
-
|
|
79
|
+
To install JaxPlan with the **automatic hyper-parameter tuning** and rddlrepository:
|
|
65
80
|
|
|
66
81
|
```shell
|
|
67
82
|
pip install pyRDDLGym-jax[extra]
|
|
68
83
|
```
|
|
69
84
|
|
|
85
|
+
(Since version 1.0) To install JaxPlan with the **visualization dashboard**:
|
|
86
|
+
|
|
87
|
+
```shell
|
|
88
|
+
pip install pyRDDLGym-jax[dashboard]
|
|
89
|
+
```
|
|
90
|
+
|
|
91
|
+
(Since version 1.0) To install JaxPlan with **all options**:
|
|
92
|
+
|
|
93
|
+
```shell
|
|
94
|
+
pip install pyRDDLGym-jax[extra,dashboard]
|
|
95
|
+
```
|
|
96
|
+
|
|
70
97
|
## Running from the Command Line
|
|
71
98
|
|
|
72
|
-
A basic run script is provided to run
|
|
99
|
+
A basic run script is provided to run JaxPlan on any domain in ``rddlrepository`` from the install directory of pyRDDLGym-jax:
|
|
73
100
|
|
|
74
101
|
```shell
|
|
75
102
|
python -m pyRDDLGym_jax.examples.run_plan <domain> <instance> <method> <episodes>
|
|
@@ -86,35 +113,15 @@ The ``method`` parameter supports three possible modes:
|
|
|
86
113
|
- ``drp`` is the deep reactive policy network described [in this paper](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
|
|
87
114
|
- ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step.
|
|
88
115
|
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
```shell
|
|
92
|
-
python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials> <iters> <workers>
|
|
93
|
-
```
|
|
94
|
-
|
|
95
|
-
where:
|
|
96
|
-
- ``domain`` is the domain identifier as specified in rddlrepository
|
|
97
|
-
- ``instance`` is the instance identifier
|
|
98
|
-
- ``method`` is the planning method to use (i.e. drp, slp, replan)
|
|
99
|
-
- ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
|
|
100
|
-
- ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
|
|
101
|
-
- ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, i.e. the total number of evaluations is ``iters * workers``.
|
|
102
|
-
|
|
103
|
-
For example, the following will train the Jax Planner on the Quadcopter domain with 4 drones:
|
|
116
|
+
For example, the following will train JaxPlan on the Quadcopter domain with 4 drones:
|
|
104
117
|
|
|
105
118
|
```shell
|
|
106
119
|
python -m pyRDDLGym_jax.examples.run_plan Quadcopter 1 slp
|
|
107
120
|
```
|
|
108
121
|
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
<p align="center">
|
|
112
|
-
<img src="Images/quadcopter.gif" width="400" height="400" margin=1/>
|
|
113
|
-
</p>
|
|
114
|
-
|
|
115
|
-
## Running from within Python
|
|
122
|
+
## Running from Another Python Application
|
|
116
123
|
|
|
117
|
-
To run
|
|
124
|
+
To run JaxPlan from within a Python application, refer to the following example:
|
|
118
125
|
|
|
119
126
|
```python
|
|
120
127
|
import pyRDDLGym
|
|
@@ -144,17 +151,15 @@ The basic structure of a configuration file is provided below for a straight-lin
|
|
|
144
151
|
```ini
|
|
145
152
|
[Model]
|
|
146
153
|
logic='FuzzyLogic'
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
154
|
+
comparison_kwargs={'weight': 20}
|
|
155
|
+
rounding_kwargs={'weight': 20}
|
|
156
|
+
control_kwargs={'weight': 20}
|
|
150
157
|
|
|
151
158
|
[Optimizer]
|
|
152
159
|
method='JaxStraightLinePlan'
|
|
153
160
|
method_kwargs={}
|
|
154
161
|
optimizer='rmsprop'
|
|
155
162
|
optimizer_kwargs={'learning_rate': 0.001}
|
|
156
|
-
batch_size_train=1
|
|
157
|
-
batch_size_test=1
|
|
158
163
|
|
|
159
164
|
[Training]
|
|
160
165
|
key=42
|
|
@@ -179,7 +184,7 @@ method_kwargs={'topology': [128, 64], 'activation': 'tanh'}
|
|
|
179
184
|
```
|
|
180
185
|
|
|
181
186
|
The configuration file must then be passed to the planner during initialization.
|
|
182
|
-
For example, the [previous script here](#running-from-within-python) can be modified to set parameters from a config file:
|
|
187
|
+
For example, the [previous script here](#running-from-another-python-application) can be modified to set parameters from a config file:
|
|
183
188
|
|
|
184
189
|
```python
|
|
185
190
|
from pyRDDLGym_jax.core.planner import load_config
|
|
@@ -193,6 +198,101 @@ controller = JaxOfflineController(planner, **train_args)
|
|
|
193
198
|
...
|
|
194
199
|
```
|
|
195
200
|
|
|
201
|
+
### JaxPlan Dashboard
|
|
202
|
+
|
|
203
|
+
Since version 1.0, JaxPlan has an optional dashboard that allows keeping track of the planner performance across multiple runs,
|
|
204
|
+
as well as visualization of the policy or model and other useful debugging features.
|
|
205
|
+
|
|
206
|
+
<p align="center">
|
|
207
|
+
<img src="Images/dashboard.png" width="480" height="248" margin=0/>
|
|
208
|
+
</p>
|
|
209
|
+
|
|
210
|
+
To run the dashboard, add the following entry to your config file:
|
|
211
|
+
|
|
212
|
+
```ini
|
|
213
|
+
...
|
|
214
|
+
[Training]
|
|
215
|
+
dashboard=True
|
|
216
|
+
...
|
|
217
|
+
```
|
|
218
|
+
|
|
219
|
+
More documentation about this and other new features will be coming soon.
|
|
220
|
+
|
|
221
|
+
### Tuning the Planner
|
|
222
|
+
|
|
223
|
+
It is easy to tune the planner's hyper-parameters efficiently and automatically using Bayesian optimization.
|
|
224
|
+
To do this, first create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
|
|
225
|
+
|
|
226
|
+
```ini
|
|
227
|
+
[Model]
|
|
228
|
+
logic='FuzzyLogic'
|
|
229
|
+
comparison_kwargs={'weight': TUNABLE_WEIGHT}
|
|
230
|
+
rounding_kwargs={'weight': TUNABLE_WEIGHT}
|
|
231
|
+
control_kwargs={'weight': TUNABLE_WEIGHT}
|
|
232
|
+
|
|
233
|
+
[Optimizer]
|
|
234
|
+
method='JaxStraightLinePlan'
|
|
235
|
+
method_kwargs={}
|
|
236
|
+
optimizer='rmsprop'
|
|
237
|
+
optimizer_kwargs={'learning_rate': TUNABLE_LEARNING_RATE}
|
|
238
|
+
|
|
239
|
+
[Training]
|
|
240
|
+
train_seconds=30
|
|
241
|
+
print_summary=False
|
|
242
|
+
print_progress=False
|
|
243
|
+
train_on_reset=True
|
|
244
|
+
```
|
|
245
|
+
|
|
246
|
+
This template allows tuning the sharpness of the model relaxations and the learning rate of the optimizer.
|
|
247
|
+
|
|
248
|
+
Next, you must link the patterns in the config with concrete hyper-parameter ranges the tuner will understand:
|
|
249
|
+
|
|
250
|
+
```python
|
|
251
|
+
import pyRDDLGym
|
|
252
|
+
from pyRDDLGym_jax.core.tuning import JaxParameterTuning, Hyperparameter
|
|
253
|
+
|
|
254
|
+
# set up the environment
|
|
255
|
+
env = pyRDDLGym.make(domain, instance, vectorized=True)
|
|
256
|
+
|
|
257
|
+
# load the config file template with planner settings
|
|
258
|
+
with open('path/to/config.cfg', 'r') as file:
|
|
259
|
+
config_template = file.read()
|
|
260
|
+
|
|
261
|
+
# map parameters in the config that will be tuned
|
|
262
|
+
def power_10(x):
|
|
263
|
+
return 10.0 ** x
|
|
264
|
+
|
|
265
|
+
hyperparams = [
|
|
266
|
+
Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10), # tune weight from 10^-1 ... 10^5
|
|
267
|
+
Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10), # tune lr from 10^-5 ... 10^1
|
|
268
|
+
]
|
|
269
|
+
|
|
270
|
+
# build the tuner and tune
|
|
271
|
+
tuning = JaxParameterTuning(env=env,
|
|
272
|
+
config_template=config_template,
|
|
273
|
+
hyperparams=hyperparams,
|
|
274
|
+
online=False,
|
|
275
|
+
eval_trials=trials,
|
|
276
|
+
num_workers=workers,
|
|
277
|
+
gp_iters=iters)
|
|
278
|
+
tuning.tune(key=42, log_file='path/to/log.csv')
|
|
279
|
+
```
|
|
280
|
+
|
|
281
|
+
A basic run script is provided to run the automatic hyper-parameter tuning for the most sensitive parameters of JaxPlan:
|
|
282
|
+
|
|
283
|
+
```shell
|
|
284
|
+
python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials> <iters> <workers>
|
|
285
|
+
```
|
|
286
|
+
|
|
287
|
+
where:
|
|
288
|
+
- ``domain`` is the domain identifier as specified in rddlrepository
|
|
289
|
+
- ``instance`` is the instance identifier
|
|
290
|
+
- ``method`` is the planning method to use (i.e. drp, slp, replan)
|
|
291
|
+
- ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
|
|
292
|
+
- ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
|
|
293
|
+
- ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, i.e. the total number of evaluations is ``iters * workers``.
|
|
294
|
+
|
|
295
|
+
|
|
196
296
|
## Simulation
|
|
197
297
|
|
|
198
298
|
The JAX compiler can be used as a backend for simulating and evaluating RDDL environments:
|
|
@@ -206,47 +306,17 @@ from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator
|
|
|
206
306
|
env = pyRDDLGym.make("domain", "instance", backend=JaxRDDLSimulator)
|
|
207
307
|
|
|
208
308
|
# evaluate the random policy
|
|
209
|
-
agent = RandomAgent(action_space=env.action_space,
|
|
210
|
-
num_actions=env.max_allowed_actions)
|
|
309
|
+
agent = RandomAgent(action_space=env.action_space, num_actions=env.max_allowed_actions)
|
|
211
310
|
agent.evaluate(env, verbose=True, render=True)
|
|
212
311
|
```
|
|
213
312
|
|
|
214
313
|
For some domains, the JAX backend could perform better than the numpy-based one, due to various compiler optimizations.
|
|
215
314
|
In any event, the simulation results using the JAX backend should (almost) always match the numpy backend.
|
|
216
315
|
|
|
217
|
-
## Manual Gradient Calculation
|
|
218
|
-
|
|
219
|
-
For custom applications, it is desirable to compute gradients of the model that can be optimized downstream.
|
|
220
|
-
Fortunately, we provide a very convenient function for compiling the transition/step function ``P(s, a, s')`` of the environment into JAX.
|
|
221
|
-
|
|
222
|
-
```python
|
|
223
|
-
import pyRDDLGym
|
|
224
|
-
from pyRDDLGym_jax.core.planner import JaxRDDLCompilerWithGrad
|
|
225
|
-
|
|
226
|
-
# set up the environment
|
|
227
|
-
env = pyRDDLGym.make("domain", "instance", vectorized=True)
|
|
228
|
-
|
|
229
|
-
# create the step function
|
|
230
|
-
compiled = JaxRDDLCompilerWithGrad(rddl=env.model)
|
|
231
|
-
compiled.compile()
|
|
232
|
-
step_fn = compiled.compile_transition()
|
|
233
|
-
```
|
|
234
|
-
|
|
235
|
-
This will return a JAX compiled (pure) function requiring the following inputs:
|
|
236
|
-
- ``key`` is the ``jax.random.PRNGKey`` key for reproducible randomness
|
|
237
|
-
- ``actions`` is the dictionary of action fluent tensors
|
|
238
|
-
- ``subs`` is the dictionary of state-fluent and non-fluent tensors
|
|
239
|
-
- ``model_params`` are the parameters of the differentiable relaxations, such as ``weight``
|
|
240
|
-
|
|
241
|
-
The function returns a dictionary containing a variety of variables, such as updated pvariables including next-state fluents (``pvar``), reward obtained (``reward``), error codes (``error``).
|
|
242
|
-
It is thus possible to apply any JAX transformation to the output of the function, such as computing gradient using ``jax.grad()`` or batched simulation using ``jax.vmap()``.
|
|
243
316
|
|
|
244
|
-
|
|
245
|
-
An [example is provided to illustrate how you can define your own policy class and compute the return gradient manually](https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/pyRDDLGym_jax/examples/run_gradient.py).
|
|
317
|
+
## Citing JaxPlan
|
|
246
318
|
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
The [following citation](https://ojs.aaai.org/index.php/ICAPS/article/view/31480) describes the main ideas of the framework. Please cite it if you found it useful:
|
|
319
|
+
The [following citation](https://ojs.aaai.org/index.php/ICAPS/article/view/31480) describes the main ideas of JaxPlan. Please cite it if you found it useful:
|
|
250
320
|
|
|
251
321
|
```
|
|
252
322
|
@inproceedings{gimelfarb2024jaxplan,
|
|
@@ -258,21 +328,8 @@ The [following citation](https://ojs.aaai.org/index.php/ICAPS/article/view/31480
|
|
|
258
328
|
}
|
|
259
329
|
```
|
|
260
330
|
|
|
261
|
-
The utility optimization is discussed in [this paper](https://ojs.aaai.org/index.php/AAAI/article/view/21226):
|
|
262
|
-
|
|
263
|
-
```
|
|
264
|
-
@inproceedings{patton2022distributional,
|
|
265
|
-
title={A distributional framework for risk-sensitive end-to-end planning in continuous mdps},
|
|
266
|
-
author={Patton, Noah and Jeong, Jihwan and Gimelfarb, Mike and Sanner, Scott},
|
|
267
|
-
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
|
|
268
|
-
volume={36},
|
|
269
|
-
number={9},
|
|
270
|
-
pages={9894--9901},
|
|
271
|
-
year={2022}
|
|
272
|
-
}
|
|
273
|
-
```
|
|
274
|
-
|
|
275
331
|
Some of the implementation details derive from the following literature, which you may wish to also cite in your research papers:
|
|
332
|
+
- [A Distributional Framework for Risk-Sensitive End-to-End Planning in Continuous MDPs, AAAI 2022](https://ojs.aaai.org/index.php/AAAI/article/view/21226)
|
|
276
333
|
- [Deep reactive policies for planning in stochastic nonlinear domains, AAAI 2019](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
|
|
277
334
|
- [Scalable planning with tensorflow for hybrid nonlinear domains, NeurIPS 2017](https://proceedings.neurips.cc/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)
|
|
278
335
|
|