pyRDDLGym-jax 0.4.tar.gz → 1.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/PKG-INFO +158 -99
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/README.md +150 -96
- pyrddlgym_jax-1.0/pyRDDLGym_jax/__init__.py +1 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/core/compiler.py +463 -592
- pyrddlgym_jax-1.0/pyRDDLGym_jax/core/logic.py +1083 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/core/planner.py +422 -474
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/core/simulator.py +7 -5
- pyrddlgym_jax-1.0/pyRDDLGym_jax/core/tuning.py +511 -0
- pyrddlgym_jax-1.0/pyRDDLGym_jax/core/visualization.py +1463 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +5 -5
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +5 -4
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +5 -4
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
- pyrddlgym_jax-1.0/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +21 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/default_replan.cfg +5 -4
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
- pyrddlgym_jax-1.0/pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
- pyrddlgym_jax-1.0/pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
- pyrddlgym_jax-1.0/pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/run_plan.py +4 -1
- pyrddlgym_jax-1.0/pyRDDLGym_jax/examples/run_tune.py +91 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax.egg-info/PKG-INFO +158 -99
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax.egg-info/SOURCES.txt +5 -4
- pyrddlgym_jax-1.0/pyRDDLGym_jax.egg-info/requires.txt +14 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/setup.py +7 -5
- pyrddlgym_jax-0.4/pyRDDLGym_jax/__init__.py +0 -1
- pyrddlgym_jax-0.4/pyRDDLGym_jax/core/logic.py +0 -781
- pyrddlgym_jax-0.4/pyRDDLGym_jax/core/tuning.py +0 -705
- pyrddlgym_jax-0.4/pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
- pyrddlgym_jax-0.4/pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
- pyrddlgym_jax-0.4/pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
- pyrddlgym_jax-0.4/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +0 -20
- pyrddlgym_jax-0.4/pyRDDLGym_jax/examples/run_tune.py +0 -80
- pyrddlgym_jax-0.4/pyRDDLGym_jax.egg-info/requires.txt +0 -7
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/LICENSE +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/core/__init__.py +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/__init__.py +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/run_gradient.py +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/run_gym.py +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/run_scipy.py +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax.egg-info/dependency_links.txt +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax.egg-info/top_level.txt +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: pyRDDLGym-jax
-Version: 0.4
+Version: 1.0
 Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
 Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
 Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -13,61 +13,90 @@ Classifier: Natural Language :: English
 Classifier: Operating System :: OS Independent
 Classifier: Programming Language :: Python :: 3
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
-Requires-Python: >=3.
+Requires-Python: >=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: pyRDDLGym>=2.0
 Requires-Dist: tqdm>=4.66
-Requires-Dist: bayesian-optimization>=1.4.3
 Requires-Dist: jax>=0.4.12
 Requires-Dist: optax>=0.1.9
 Requires-Dist: dm-haiku>=0.0.10
 Requires-Dist: tensorflow-probability>=0.21.0
+Provides-Extra: extra
+Requires-Dist: bayesian-optimization>=2.0.0; extra == "extra"
+Requires-Dist: rddlrepository>=2.0; extra == "extra"
+Provides-Extra: dashboard
+Requires-Dist: dash>=2.18.0; extra == "dashboard"
+Requires-Dist: dash-bootstrap-components>=1.6.0; extra == "dashboard"
 
 # pyRDDLGym-jax
 
-
+**pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.** It provides:
 
-
-
-
+1. automatic translation of any RDDL description file into a differentiable simulator in JAX
+2. flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and Bayesian hyper-parameter tuning.
+
+Some demos of solved problems by JaxPlan:
+
+<p align="middle">
+<img src="Images/intruders.gif" width="120" height="120" margin=0/>
+<img src="Images/marsrover.gif" width="120" height="120" margin=0/>
+<img src="Images/pong.gif" width="120" height="120" margin=0/>
+<img src="Images/quadcopter.gif" width="120" height="120" margin=0/>
+<img src="Images/reacher.gif" width="120" height="120" margin=0/>
+<img src="Images/reservoir.gif" width="120" height="120" margin=0/>
+</p>
+
+> [!WARNING]
+> Starting in version 1.0 (major release), the ``weight`` parameter in the config file was removed,
+and was moved to the individual logic components which have their own unique weight parameter assigned.
+> Furthermore, the tuning module has been redesigned from the ground up, and supports tuning arbitrary hyper-parameters via config templates!
+> Finally, the terrible visualizer for the planner was removed and replaced with an interactive real-time dashboard (similar to tensorboard, but custom designed for the planner)!
 
 > [!NOTE]
-> While
+> While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
 > If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
 
 ## Contents
 
 - [Installation](#installation)
 - [Running from the Command Line](#running-from-the-command-line)
-- [Running from
+- [Running from Another Python Application](#running-from-another-python-application)
 - [Configuring the Planner](#configuring-the-planner)
+- [JaxPlan Dashboard](#jaxplan-dashboard)
+- [Tuning the Planner](#tuning-the-planner)
 - [Simulation](#simulation)
-- [
-- [Citing pyRDDLGym-jax](#citing-pyrddlgym-jax)
+- [Citing JaxPlan](#citing-jaxplan)
 
 ## Installation
 
-To
-
-
-
-
-
-
+To install the bare-bones version of JaxPlan with **minimum installation requirements**:
+
+```shell
+pip install pyRDDLGym-jax
+```
+
+To install JaxPlan with the **automatic hyper-parameter tuning** and rddlrepository:
 
-
-
+```shell
+pip install pyRDDLGym-jax[extra]
+```
 
-
+(Since version 1.0) To install JaxPlan with the **visualization dashboard**:
 
 ```shell
-pip install
+pip install pyRDDLGym-jax[dashboard]
+```
+
+(Since version 1.0) To install JaxPlan with **all options**:
+
+```shell
+pip install pyRDDLGym-jax[extra,dashboard]
 ```
 
 ## Running from the Command Line
 
-A basic run script is provided to run
+A basic run script is provided to run JaxPlan on any domain in ``rddlrepository`` from the install directory of pyRDDLGym-jax:
 
 ```shell
 python -m pyRDDLGym_jax.examples.run_plan <domain> <instance> <method> <episodes>
@@ -84,35 +113,15 @@ The ``method`` parameter supports three possible modes:
 - ``drp`` is the deep reactive policy network described [in this paper](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
 - ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step.
 
-
-
-```shell
-python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials> <iters> <workers>
-```
-
-where:
-- ``domain`` is the domain identifier as specified in rddlrepository (i.e. Wildfire_MDP_ippc2014)
-- ``instance`` is the instance identifier (i.e. 1, 2, ... 10)
-- ``method`` is the planning method to use (i.e. drp, slp, replan)
-- ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
-- ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
-- ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``.
-
-For example, the following will train the Jax Planner on the Quadcopter domain with 4 drones:
+For example, the following will train JaxPlan on the Quadcopter domain with 4 drones:
 
 ```shell
 python -m pyRDDLGym_jax.examples.run_plan Quadcopter 1 slp
 ```
 
-
-
-<p align="center">
-<img src="Images/quadcopter.gif" width="400" height="400" margin=1/>
-</p>
-
-## Running from within Python
+## Running from Another Python Application
 
-To run
+To run JaxPlan from within a Python application, refer to the following example:
 
 ```python
 import pyRDDLGym
@@ -142,17 +151,15 @@ The basic structure of a configuration file is provided below for a straight-lin
 ```ini
 [Model]
 logic='FuzzyLogic'
-
-
-
+comparison_kwargs={'weight': 20}
+rounding_kwargs={'weight': 20}
+control_kwargs={'weight': 20}
 
 [Optimizer]
 method='JaxStraightLinePlan'
 method_kwargs={}
 optimizer='rmsprop'
 optimizer_kwargs={'learning_rate': 0.001}
-batch_size_train=1
-batch_size_test=1
 
 [Training]
 key=42
@@ -177,7 +184,7 @@ method_kwargs={'topology': [128, 64], 'activation': 'tanh'}
 ```
 
 The configuration file must then be passed to the planner during initialization.
-For example, the [previous script here](#running-from-
+For example, the [previous script here](#running-from-another-python-application) can be modified to set parameters from a config file:
 
 ```python
 from pyRDDLGym_jax.core.planner import load_config
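The diff only shows fragments of this config-loading example (the context window cuts off after the ``load_config`` import, and the ``JaxOfflineController`` call only appears in the next hunk header). A minimal sketch of how the pieces fit together is given below; the ``JaxBackpropPlanner`` import and the three-tuple returned by ``load_config`` are assumptions not visible in this diff, and the domain/instance names and config path are placeholders.

```python
import pyRDDLGym
# JaxBackpropPlanner and the (planner_args, plan_args, train_args) unpacking are assumptions;
# only load_config and JaxOfflineController are visible in the surrounding diff context.
from pyRDDLGym_jax.core.planner import (
    JaxBackpropPlanner, JaxOfflineController, load_config)

# create the vectorized environment (placeholder domain/instance)
env = pyRDDLGym.make("domain", "instance", vectorized=True)

# read the planner and training settings from the config file
planner_args, _, train_args = load_config("path/to/config.cfg")

# build the planner and wrap it in an offline controller
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
controller = JaxOfflineController(planner, **train_args)

# train and evaluate the controller
controller.evaluate(env, verbose=True, render=True)
env.close()
```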
@@ -191,6 +198,101 @@ controller = JaxOfflineController(planner, **train_args)
 ...
 ```
 
+### JaxPlan Dashboard
+
+Since version 1.0, JaxPlan has an optional dashboard that allows keeping track of the planner performance across multiple runs,
+and visualization of the policy or model, and other useful debugging features.
+
+<p align="middle">
+<img src="Images/dashboard.png" width="480" height="248" margin=0/>
+</p>
+
+To run the dashboard, add the following entry to your config file:
+
+```ini
+...
+[Training]
+dashboard=True
+...
+```
+
+More documentation about this and other new features will be coming soon.
+
+### Tuning the Planner
+
+It is easy to tune the planner's hyper-parameters efficiently and automatically using Bayesian optimization.
+To do this, first create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
+
+```ini
+[Model]
+logic='FuzzyLogic'
+comparison_kwargs={'weight': TUNABLE_WEIGHT}
+rounding_kwargs={'weight': TUNABLE_WEIGHT}
+control_kwargs={'weight': TUNABLE_WEIGHT}
+
+[Optimizer]
+method='JaxStraightLinePlan'
+method_kwargs={}
+optimizer='rmsprop'
+optimizer_kwargs={'learning_rate': TUNABLE_LEARNING_RATE}
+
+[Training]
+train_seconds=30
+print_summary=False
+print_progress=False
+train_on_reset=True
+```
+
+would allow to tune the sharpness of model relaxations, and the learning rate of the optimizer.
+
+Next, you must link the patterns in the config with concrete hyper-parameter ranges the tuner will understand:
+
+```python
+import pyRDDLGym
+from pyRDDLGym_jax.core.tuning import JaxParameterTuning, Hyperparameter
+
+# set up the environment
+env = pyRDDLGym.make(domain, instance, vectorized=True)
+
+# load the config file template with planner settings
+with open('path/to/config.cfg', 'r') as file:
+    config_template = file.read()
+
+# map parameters in the config that will be tuned
+def power_10(x):
+    return 10.0 ** x
+
+hyperparams = [
+    Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10), # tune weight from 10^-1 ... 10^5
+    Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10), # tune lr from 10^-5 ... 10^1
+]
+
+# build the tuner and tune
+tuning = JaxParameterTuning(env=env,
+                            config_template=config_template,
+                            hyperparams=hyperparams,
+                            online=False,
+                            eval_trials=trials,
+                            num_workers=workers,
+                            gp_iters=iters)
+tuning.tune(key=42, log_file='path/to/log.csv')
+```
+
+A basic run script is provided to run the automatic hyper-parameter tuning for the most sensitive parameters of JaxPlan:
+
+```shell
+python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials> <iters> <workers>
+```
+
+where:
+- ``domain`` is the domain identifier as specified in rddlrepository
+- ``instance`` is the instance identifier
+- ``method`` is the planning method to use (i.e. drp, slp, replan)
+- ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
+- ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
+- ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``.
+
+
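As a concrete illustration of the command-line tuning interface described above, a hypothetical invocation might look like the following; the domain identifier reuses the ``Wildfire_MDP_ippc2014`` example cited elsewhere in this diff, and the numeric arguments are arbitrary illustrative values.

```shell
# 5 trials per hyper-parameter setting, 20 Bayesian-optimization iterations, 4 parallel workers
python -m pyRDDLGym_jax.examples.run_tune Wildfire_MDP_ippc2014 1 slp 5 20 4
```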
 ## Simulation
 
 The JAX compiler can be used as a backend for simulating and evaluating RDDL environments:
@@ -204,47 +306,17 @@ from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator
 env = pyRDDLGym.make("domain", "instance", backend=JaxRDDLSimulator)
 
 # evaluate the random policy
-agent = RandomAgent(action_space=env.action_space,
-                    num_actions=env.max_allowed_actions)
+agent = RandomAgent(action_space=env.action_space, num_actions=env.max_allowed_actions)
 agent.evaluate(env, verbose=True, render=True)
 ```
 
 For some domains, the JAX backend could perform better than the numpy-based one, due to various compiler optimizations.
 In any event, the simulation results using the JAX backend should (almost) always match the numpy backend.
 
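To make the backend comparison above concrete, the following sketch runs the same random-policy evaluation under both the default (numpy) backend and the JAX backend; the domain/instance names are placeholders, the ``RandomAgent`` import path follows pyRDDLGym conventions, and treating the value returned by ``evaluate`` as a summary-statistics dictionary is an assumption not confirmed by this diff.

```python
import pyRDDLGym
from pyRDDLGym.core.policy import RandomAgent
from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator

# evaluate the same random policy under the numpy backend and the JAX backend
for label, kwargs in [("numpy", {}), ("jax", {"backend": JaxRDDLSimulator})]:
    env = pyRDDLGym.make("domain", "instance", **kwargs)
    agent = RandomAgent(action_space=env.action_space,
                        num_actions=env.max_allowed_actions)
    stats = agent.evaluate(env, verbose=False, render=False)
    print(label, stats)  # the two summaries should be (almost) identical in distribution
    env.close()
```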
-## Manual Gradient Calculation
-
-For custom applications, it is desirable to compute gradients of the model that can be optimized downstream.
-Fortunately, we provide a very convenient function for compiling the transition/step function ``P(s, a, s')`` of the environment into JAX.
 
-
-import pyRDDLGym
-from pyRDDLGym_jax.core.planner import JaxRDDLCompilerWithGrad
+## Citing JaxPlan
 
-
-env = pyRDDLGym.make("domain", "instance", vectorized=True)
-
-# create the step function
-compiled = JaxRDDLCompilerWithGrad(rddl=env.model)
-compiled.compile()
-step_fn = compiled.compile_transition()
-```
-
-This will return a JAX compiled (pure) function requiring the following inputs:
-- ``key`` is the ``jax.random.PRNGKey`` key for reproducible randomness
-- ``actions`` is the dictionary of action fluent tensors
-- ``subs`` is the dictionary of state-fluent and non-fluent tensors
-- ``model_params`` are the parameters of the differentiable relaxations, such as ``weight``
-
-The function returns a dictionary containing a variety of variables, such as updated pvariables including next-state fluents (``pvar``), reward obtained (``reward``), error codes (``error``).
-It is thus possible to apply any JAX transformation to the output of the function, such as computing gradient using ``jax.grad()`` or batched simulation using ``jax.vmap()``.
-
-Compilation of entire rollouts is also possible by calling the ``compile_rollouts`` function.
-An [example is provided to illustrate how you can define your own policy class and compute the return gradient manually](https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/pyRDDLGym_jax/examples/run_gradient.py).
-
-## Citing pyRDDLGym-jax
-
-The [following citation](https://ojs.aaai.org/index.php/ICAPS/article/view/31480) describes the main ideas of the framework. Please cite it if you found it useful:
+The [following citation](https://ojs.aaai.org/index.php/ICAPS/article/view/31480) describes the main ideas of JaxPlan. Please cite it if you found it useful:
 
 ```
 @inproceedings{gimelfarb2024jaxplan,
@@ -256,21 +328,8 @@ The [following citation](https://ojs.aaai.org/index.php/ICAPS/article/view/31480
 }
 ```
 
-The utility optimization is discussed in [this paper](https://ojs.aaai.org/index.php/AAAI/article/view/21226):
-
-```
-@inproceedings{patton2022distributional,
-  title={A distributional framework for risk-sensitive end-to-end planning in continuous mdps},
-  author={Patton, Noah and Jeong, Jihwan and Gimelfarb, Mike and Sanner, Scott},
-  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
-  volume={36},
-  number={9},
-  pages={9894--9901},
-  year={2022}
-}
-```
-
 Some of the implementation details derive from the following literature, which you may wish to also cite in your research papers:
+- [A Distributional Framework for Risk-Sensitive End-to-End Planning in Continuous MDPs](https://ojs.aaai.org/index.php/AAAI/article/view/21226)
 - [Deep reactive policies for planning in stochastic nonlinear domains, AAAI 2019](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
 - [Scalable planning with tensorflow for hybrid nonlinear domains, NeurIPS 2017](https://proceedings.neurips.cc/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)
 