pyRDDLGym-jax 0.5__tar.gz → 1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/PKG-INFO +159 -103
  2. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/README.md +155 -102
  3. pyrddlgym_jax-1.1/pyRDDLGym_jax/__init__.py +1 -0
  4. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/core/compiler.py +463 -592
  5. pyrddlgym_jax-1.1/pyRDDLGym_jax/core/logic.py +1083 -0
  6. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/core/planner.py +336 -472
  7. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/core/simulator.py +7 -5
  8. pyrddlgym_jax-1.1/pyRDDLGym_jax/core/tuning.py +525 -0
  9. pyrddlgym_jax-1.1/pyRDDLGym_jax/core/visualization.py +1463 -0
  10. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
  11. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +4 -5
  12. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
  13. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
  14. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
  15. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
  16. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
  17. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
  18. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +3 -3
  19. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
  20. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
  21. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
  22. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
  23. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +3 -3
  24. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
  25. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
  26. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
  27. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +3 -3
  28. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
  29. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
  30. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/default_replan.cfg +3 -3
  31. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
  32. pyrddlgym_jax-1.1/pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
  33. pyrddlgym_jax-1.1/pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
  34. pyrddlgym_jax-1.1/pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
  35. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/run_plan.py +4 -1
  36. pyrddlgym_jax-1.1/pyRDDLGym_jax/examples/run_tune.py +91 -0
  37. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax.egg-info/PKG-INFO +159 -103
  38. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax.egg-info/SOURCES.txt +5 -4
  39. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax.egg-info/requires.txt +4 -0
  40. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/setup.py +3 -2
  41. pyrddlgym_jax-0.5/pyRDDLGym_jax/__init__.py +0 -1
  42. pyrddlgym_jax-0.5/pyRDDLGym_jax/core/logic.py +0 -843
  43. pyrddlgym_jax-0.5/pyRDDLGym_jax/core/tuning.py +0 -700
  44. pyrddlgym_jax-0.5/pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
  45. pyrddlgym_jax-0.5/pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
  46. pyrddlgym_jax-0.5/pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
  47. pyrddlgym_jax-0.5/pyRDDLGym_jax/examples/run_tune.py +0 -78
  48. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/LICENSE +0 -0
  49. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/core/__init__.py +0 -0
  50. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/__init__.py +0 -0
  51. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
  52. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/run_gradient.py +0 -0
  53. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/run_gym.py +0 -0
  54. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/run_scipy.py +0 -0
  55. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax.egg-info/dependency_links.txt +0 -0
  56. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/pyRDDLGym_jax.egg-info/top_level.txt +0 -0
  57. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.1}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: pyRDDLGym-jax
- Version: 0.5
+ Version: 1.1
  Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -25,51 +25,77 @@ Requires-Dist: tensorflow-probability>=0.21.0
  Provides-Extra: extra
  Requires-Dist: bayesian-optimization>=2.0.0; extra == "extra"
  Requires-Dist: rddlrepository>=2.0; extra == "extra"
+ Provides-Extra: dashboard
+ Requires-Dist: dash>=2.18.0; extra == "dashboard"
+ Requires-Dist: dash-bootstrap-components>=1.6.0; extra == "dashboard"

  # pyRDDLGym-jax

- Author: [Mike Gimelfarb](https://mike-gimelfarb.github.io)
+ ![Python Version](https://img.shields.io/badge/python-3.9%2B-blue)
+ [![PyPI Version](https://img.shields.io/pypi/v/pyRDDLGym-jax.svg)](https://pypi.org/project/pyRDDLGym-jax/)
+ [![Documentation Status](https://readthedocs.org/projects/pyrddlgym/badge/?version=latest)](https://pyrddlgym.readthedocs.io/en/latest/jax.html)
+ ![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)
+ [![Cumulative PyPI Downloads](https://img.shields.io/pypi/dm/pyrddlgym-jax)](https://pypistats.org/packages/pyrddlgym-jax)

- This directory provides:
- 1. automated translation and compilation of RDDL description files into [JAX](https://github.com/google/jax), converting any RDDL domain to a differentiable simulator!
- 2. powerful, fast and scalable gradient-based planning algorithms, with extendible and flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and much more!
+ [Installation](#installation) | [Run cmd](#running-from-the-command-line) | [Run python](#running-from-another-python-application) | [Configuration](#configuring-the-planner) | [Dashboard](#jaxplan-dashboard) | [Tuning](#tuning-the-planner) | [Simulation](#simulation) | [Citing](#citing-jaxplan)

- > [!NOTE]
- > While Jax planners can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
- > If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
+ **pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.**
+
+ Purpose:
+
+ 1. automatic translation of any RDDL description file into a differentiable simulator in JAX
+ 2. flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and Bayesian hyper-parameter tuning.
+
+ Some demos of solved problems by JaxPlan:
+
+ <p align="middle">
+ <img src="Images/intruders.gif" width="120" height="120" margin=0/>
+ <img src="Images/marsrover.gif" width="120" height="120" margin=0/>
+ <img src="Images/pong.gif" width="120" height="120" margin=0/>
+ <img src="Images/quadcopter.gif" width="120" height="120" margin=0/>
+ <img src="Images/reacher.gif" width="120" height="120" margin=0/>
+ <img src="Images/reservoir.gif" width="120" height="120" margin=0/>
+ </p>

- ## Contents
+ > [!WARNING]
+ > Starting in version 1.0 (major release), the ``weight`` parameter in the config file was removed,
+ > and was moved to the individual logic components which have their own unique weight parameter assigned.
+ > Furthermore, the tuning module has been redesigned from the ground up, and supports tuning arbitrary hyper-parameters via config templates!
+ > Finally, the terrible visualizer for the planner was removed and replaced with an interactive real-time dashboard (similar to tensorboard, but custom designed for the planner)!

- - [Installation](#installation)
- - [Running from the Command Line](#running-from-the-command-line)
- - [Running from within Python](#running-from-within-python)
- - [Configuring the Planner](#configuring-the-planner)
- - [Simulation](#simulation)
- - [Manual Gradient Calculation](#manual-gradient-calculation)
- - [Citing pyRDDLGym-jax](#citing-pyrddlgym-jax)
+ > [!NOTE]
+ > While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
+ > If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).

  ## Installation

- To use the compiler or planner without the automated hyper-parameter tuning, you will need the following packages installed:
- - ``pyRDDLGym>=2.0``
- - ``tqdm>=4.66``
- - ``jax>=0.4.12``
- - ``optax>=0.1.9``
- - ``dm-haiku>=0.0.10``
- - ``tensorflow-probability>=0.21.0``
+ To install the bare-bones version of JaxPlan with **minimum installation requirements**:

- Additionally, if you wish to run the examples, you need ``rddlrepository>=2``.
- To run the automated tuning optimization, you will also need ``bayesian-optimization>=2.0.0``.
+ ```shell
+ pip install pyRDDLGym-jax
+ ```

- You can install pyRDDLGym-jax with all requirements using pip:
+ To install JaxPlan with the **automatic hyper-parameter tuning** and rddlrepository:

  ```shell
  pip install pyRDDLGym-jax[extra]
  ```

+ (Since version 1.0) To install JaxPlan with the **visualization dashboard**:
+
+ ```shell
+ pip install pyRDDLGym-jax[dashboard]
+ ```
+
+ (Since version 1.0) To install JaxPlan with **all options**:
+
+ ```shell
+ pip install pyRDDLGym-jax[extra,dashboard]
+ ```
+
  ## Running from the Command Line

- A basic run script is provided to run the Jax Planner on any domain in ``rddlrepository`` from the install directory of pyRDDLGym-jax:
+ A basic run script is provided to run JaxPlan on any domain in ``rddlrepository`` from the install directory of pyRDDLGym-jax:

  ```shell
  python -m pyRDDLGym_jax.examples.run_plan <domain> <instance> <method> <episodes>
@@ -86,35 +112,15 @@ The ``method`` parameter supports three possible modes:
  - ``drp`` is the deep reactive policy network described [in this paper](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
  - ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step.

- A basic run script is also provided to run the automatic hyper-parameter tuning:
-
- ```shell
- python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials> <iters> <workers>
- ```
-
- where:
- - ``domain`` is the domain identifier as specified in rddlrepository
- - ``instance`` is the instance identifier
- - ``method`` is the planning method to use (i.e. drp, slp, replan)
- - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
- - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
- - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``.
-
- For example, the following will train the Jax Planner on the Quadcopter domain with 4 drones:
+ For example, the following will train JaxPlan on the Quadcopter domain with 4 drones:

  ```shell
  python -m pyRDDLGym_jax.examples.run_plan Quadcopter 1 slp
  ```

- After several minutes of optimization, you should get a visualization as follows:
+ ## Running from Another Python Application

- <p align="center">
- <img src="Images/quadcopter.gif" width="400" height="400" margin=1/>
- </p>
-
- ## Running from within Python
-
- To run the Jax planner from within a Python application, refer to the following example:
+ To run JaxPlan from within a Python application, refer to the following example:

  ```python
  import pyRDDLGym
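The hunk is cut off before the body of this example. For orientation, here is a minimal sketch of the offline-planning usage it introduces, assuming a planner class named ``JaxBackpropPlanner`` (not shown in this hunk) alongside the ``JaxOfflineController`` constructor and the standard ``evaluate`` loop that do appear elsewhere in this diff:

```python
import pyRDDLGym
# JaxBackpropPlanner is assumed here; JaxOfflineController appears later in this diff
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOfflineController

# set up the environment (the planner requires the vectorized observation format)
env = pyRDDLGym.make("Quadcopter", "1", vectorized=True)

# create the planning algorithm and wrap it in an offline (open-loop) controller
planner = JaxBackpropPlanner(rddl=env.model)
controller = JaxOfflineController(planner)

# optimize the plan, then roll it out in the environment
controller.evaluate(env, episodes=1, verbose=True, render=True)
env.close()
```

The next hunks show how the same construction is driven from a config file via ``load_config``.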
@@ -144,17 +150,15 @@ The basic structure of a configuration file is provided below for a straight-lin
  ```ini
  [Model]
  logic='FuzzyLogic'
- logic_kwargs={'weight': 20}
- tnorm='ProductTNorm'
- tnorm_kwargs={}
+ comparison_kwargs={'weight': 20}
+ rounding_kwargs={'weight': 20}
+ control_kwargs={'weight': 20}

  [Optimizer]
  method='JaxStraightLinePlan'
  method_kwargs={}
  optimizer='rmsprop'
  optimizer_kwargs={'learning_rate': 0.001}
- batch_size_train=1
- batch_size_test=1

  [Training]
  key=42
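Assembled from this hunk and the ``[Training]`` keys that appear later in the diff (``train_seconds``, the print flags, ``dashboard``), a full config under the new scheme would plausibly read as follows; the key names are taken from the diff, the values are illustrative only:

```ini
[Model]
logic='FuzzyLogic'
comparison_kwargs={'weight': 20}
rounding_kwargs={'weight': 20}
control_kwargs={'weight': 20}

[Optimizer]
method='JaxStraightLinePlan'
method_kwargs={}
optimizer='rmsprop'
optimizer_kwargs={'learning_rate': 0.001}

[Training]
key=42
train_seconds=30
print_summary=True
print_progress=True
dashboard=False
```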
@@ -179,7 +183,7 @@ method_kwargs={'topology': [128, 64], 'activation': 'tanh'}
  ```

  The configuration file must then be passed to the planner during initialization.
- For example, the [previous script here](#running-from-within-python) can be modified to set parameters from a config file:
+ For example, the [previous script here](#running-from-another-python-application) can be modified to set parameters from a config file:

  ```python
  from pyRDDLGym_jax.core.planner import load_config
@@ -193,6 +197,101 @@ controller = JaxOfflineController(planner, **train_args)
  ...
  ```

+ ### JaxPlan Dashboard
+
+ Since version 1.0, JaxPlan has an optional dashboard that tracks planner performance across multiple runs,
+ visualizes the policy or model, and provides other useful debugging features.
+
+ <p align="middle">
+ <img src="Images/dashboard.png" width="480" height="248" margin=0/>
+ </p>
+
+ To run the dashboard, add the following entry to your config file:
+
+ ```ini
+ ...
+ [Training]
+ dashboard=True
+ ...
+ ```
+
+ More documentation about this and other new features will be coming soon.
+
+ ### Tuning the Planner
+
+ It is easy to tune the planner's hyper-parameters efficiently and automatically using Bayesian optimization.
+ To do this, first create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
+
+ ```ini
+ [Model]
+ logic='FuzzyLogic'
+ comparison_kwargs={'weight': TUNABLE_WEIGHT}
+ rounding_kwargs={'weight': TUNABLE_WEIGHT}
+ control_kwargs={'weight': TUNABLE_WEIGHT}
+
+ [Optimizer]
+ method='JaxStraightLinePlan'
+ method_kwargs={}
+ optimizer='rmsprop'
+ optimizer_kwargs={'learning_rate': TUNABLE_LEARNING_RATE}
+
+ [Training]
+ train_seconds=30
+ print_summary=False
+ print_progress=False
+ train_on_reset=True
+ ```
+
+ would allow tuning the sharpness of the model relaxations and the learning rate of the optimizer.
+
+ Next, you must link the patterns in the config with concrete hyper-parameter ranges the tuner will understand:
+
+ ```python
+ import pyRDDLGym
+ from pyRDDLGym_jax.core.tuning import JaxParameterTuning, Hyperparameter
+
+ # set up the environment
+ env = pyRDDLGym.make(domain, instance, vectorized=True)
+
+ # load the config file template with planner settings
+ with open('path/to/config.cfg', 'r') as file:
+     config_template = file.read()
+
+ # map parameters in the config that will be tuned
+ def power_10(x):
+     return 10.0 ** x
+
+ hyperparams = [
+     Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),         # tune weight from 10^-1 ... 10^5
+     Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10),  # tune lr from 10^-5 ... 10^1
+ ]
+
+ # build the tuner and tune
+ tuning = JaxParameterTuning(env=env,
+                             config_template=config_template,
+                             hyperparams=hyperparams,
+                             online=False,
+                             eval_trials=trials,
+                             num_workers=workers,
+                             gp_iters=iters)
+ tuning.tune(key=42, log_file='path/to/log.csv')
+ ```
+
+ A basic run script is provided to run the automatic hyper-parameter tuning for the most sensitive parameters of JaxPlan:
+
+ ```shell
+ python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials> <iters> <workers>
+ ```
+
+ where:
+ - ``domain`` is the domain identifier as specified in rddlrepository
+ - ``instance`` is the instance identifier
+ - ``method`` is the planning method to use (i.e. drp, slp, replan)
+ - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
+ - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
+ - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, i.e. the total evaluations = ``iters * workers``.
+
+
  ## Simulation

  The JAX compiler can be used as a backend for simulating and evaluating RDDL environments:
@@ -206,47 +305,17 @@ from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator
  env = pyRDDLGym.make("domain", "instance", backend=JaxRDDLSimulator)

  # evaluate the random policy
- agent = RandomAgent(action_space=env.action_space,
-                     num_actions=env.max_allowed_actions)
+ agent = RandomAgent(action_space=env.action_space, num_actions=env.max_allowed_actions)
  agent.evaluate(env, verbose=True, render=True)
  ```

  For some domains, the JAX backend could perform better than the numpy-based one, due to various compiler optimizations.
  In any event, the simulation results using the JAX backend should (almost) always match the numpy backend.

- ## Manual Gradient Calculation
-
- For custom applications, it is desirable to compute gradients of the model that can be optimized downstream.
- Fortunately, we provide a very convenient function for compiling the transition/step function ``P(s, a, s')`` of the environment into JAX.

- ```python
- import pyRDDLGym
- from pyRDDLGym_jax.core.planner import JaxRDDLCompilerWithGrad
+ ## Citing JaxPlan

- # set up the environment
- env = pyRDDLGym.make("domain", "instance", vectorized=True)
-
- # create the step function
- compiled = JaxRDDLCompilerWithGrad(rddl=env.model)
- compiled.compile()
- step_fn = compiled.compile_transition()
- ```
-
- This will return a JAX compiled (pure) function requiring the following inputs:
- - ``key`` is the ``jax.random.PRNGKey`` key for reproducible randomness
- - ``actions`` is the dictionary of action fluent tensors
- - ``subs`` is the dictionary of state-fluent and non-fluent tensors
- - ``model_params`` are the parameters of the differentiable relaxations, such as ``weight``
-
- The function returns a dictionary containing a variety of variables, such as updated pvariables including next-state fluents (``pvar``), reward obtained (``reward``), error codes (``error``).
- It is thus possible to apply any JAX transformation to the output of the function, such as computing gradient using ``jax.grad()`` or batched simulation using ``jax.vmap()``.
-
- Compilation of entire rollouts is also possible by calling the ``compile_rollouts`` function.
- An [example is provided to illustrate how you can define your own policy class and compute the return gradient manually](https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/pyRDDLGym_jax/examples/run_gradient.py).
-
- ## Citing pyRDDLGym-jax
-
- The [following citation](https://ojs.aaai.org/index.php/ICAPS/article/view/31480) describes the main ideas of the framework. Please cite it if you found it useful:
+ The [following citation](https://ojs.aaai.org/index.php/ICAPS/article/view/31480) describes the main ideas of JaxPlan. Please cite it if you found it useful:

  ```
  @inproceedings{gimelfarb2024jaxplan,
@@ -258,21 +327,8 @@ The [following citation](https://ojs.aaai.org/index.php/ICAPS/article/view/31480
  }
  ```

- The utility optimization is discussed in [this paper](https://ojs.aaai.org/index.php/AAAI/article/view/21226):
-
- ```
- @inproceedings{patton2022distributional,
-     title={A distributional framework for risk-sensitive end-to-end planning in continuous mdps},
-     author={Patton, Noah and Jeong, Jihwan and Gimelfarb, Mike and Sanner, Scott},
-     booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
-     volume={36},
-     number={9},
-     pages={9894--9901},
-     year={2022}
- }
- ```
-
  Some of the implementation details derive from the following literature, which you may wish to also cite in your research papers:
+ - [A Distributional Framework for Risk-Sensitive End-to-End Planning in Continuous MDPs](https://ojs.aaai.org/index.php/AAAI/article/view/21226)
  - [Deep reactive policies for planning in stochastic nonlinear domains, AAAI 2019](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
  - [Scalable planning with tensorflow for hybrid nonlinear domains, NeurIPS 2017](https://proceedings.neurips.cc/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)