pyRDDLGym-jax 0.4.tar.gz → 1.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/PKG-INFO +158 -99
  2. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/README.md +150 -96
  3. pyrddlgym_jax-1.0/pyRDDLGym_jax/__init__.py +1 -0
  4. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/core/compiler.py +463 -592
  5. pyrddlgym_jax-1.0/pyRDDLGym_jax/core/logic.py +1083 -0
  6. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/core/planner.py +422 -474
  7. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/core/simulator.py +7 -5
  8. pyrddlgym_jax-1.0/pyRDDLGym_jax/core/tuning.py +511 -0
  9. pyrddlgym_jax-1.0/pyRDDLGym_jax/core/visualization.py +1463 -0
  10. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
  11. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +5 -5
  12. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
  13. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
  14. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
  15. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
  16. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
  17. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
  18. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +5 -4
  19. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
  20. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
  21. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
  22. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
  23. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +5 -4
  24. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
  25. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
  26. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
  27. pyrddlgym_jax-1.0/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +21 -0
  28. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
  29. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
  30. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/default_replan.cfg +5 -4
  31. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
  32. pyrddlgym_jax-1.0/pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
  33. pyrddlgym_jax-1.0/pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
  34. pyrddlgym_jax-1.0/pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
  35. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/run_plan.py +4 -1
  36. pyrddlgym_jax-1.0/pyRDDLGym_jax/examples/run_tune.py +91 -0
  37. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax.egg-info/PKG-INFO +158 -99
  38. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax.egg-info/SOURCES.txt +5 -4
  39. pyrddlgym_jax-1.0/pyRDDLGym_jax.egg-info/requires.txt +14 -0
  40. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/setup.py +7 -5
  41. pyrddlgym_jax-0.4/pyRDDLGym_jax/__init__.py +0 -1
  42. pyrddlgym_jax-0.4/pyRDDLGym_jax/core/logic.py +0 -781
  43. pyrddlgym_jax-0.4/pyRDDLGym_jax/core/tuning.py +0 -705
  44. pyrddlgym_jax-0.4/pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
  45. pyrddlgym_jax-0.4/pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
  46. pyrddlgym_jax-0.4/pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
  47. pyrddlgym_jax-0.4/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +0 -20
  48. pyrddlgym_jax-0.4/pyRDDLGym_jax/examples/run_tune.py +0 -80
  49. pyrddlgym_jax-0.4/pyRDDLGym_jax.egg-info/requires.txt +0 -7
  50. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/LICENSE +0 -0
  51. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/core/__init__.py +0 -0
  52. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/__init__.py +0 -0
  53. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
  54. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/run_gradient.py +0 -0
  55. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/run_gym.py +0 -0
  56. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/run_scipy.py +0 -0
  57. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax.egg-info/dependency_links.txt +0 -0
  58. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/pyRDDLGym_jax.egg-info/top_level.txt +0 -0
  59. {pyrddlgym_jax-0.4 → pyrddlgym_jax-1.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: pyRDDLGym-jax
- Version: 0.4
+ Version: 1.0
  Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -13,61 +13,90 @@ Classifier: Natural Language :: English
  Classifier: Operating System :: OS Independent
  Classifier: Programming Language :: Python :: 3
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
- Requires-Python: >=3.8
+ Requires-Python: >=3.9
  Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: pyRDDLGym>=2.0
  Requires-Dist: tqdm>=4.66
- Requires-Dist: bayesian-optimization>=1.4.3
  Requires-Dist: jax>=0.4.12
  Requires-Dist: optax>=0.1.9
  Requires-Dist: dm-haiku>=0.0.10
  Requires-Dist: tensorflow-probability>=0.21.0
+ Provides-Extra: extra
+ Requires-Dist: bayesian-optimization>=2.0.0; extra == "extra"
+ Requires-Dist: rddlrepository>=2.0; extra == "extra"
+ Provides-Extra: dashboard
+ Requires-Dist: dash>=2.18.0; extra == "dashboard"
+ Requires-Dist: dash-bootstrap-components>=1.6.0; extra == "dashboard"

  # pyRDDLGym-jax

- Author: [Mike Gimelfarb](https://mike-gimelfarb.github.io)
+ **pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.** It provides:

- This directory provides:
- 1. automated translation and compilation of RDDL description files into [JAX](https://github.com/google/jax), converting any RDDL domain to a differentiable simulator!
- 2. powerful, fast and scalable gradient-based planning algorithms, with extendible and flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and much more!
+ 1. automatic translation of any RDDL description file into a differentiable simulator in JAX
+ 2. flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and Bayesian hyper-parameter tuning.
+
+ Some demos of problems solved by JaxPlan:
+
+ <p align="middle">
+ <img src="Images/intruders.gif" width="120" height="120" margin=0/>
+ <img src="Images/marsrover.gif" width="120" height="120" margin=0/>
+ <img src="Images/pong.gif" width="120" height="120" margin=0/>
+ <img src="Images/quadcopter.gif" width="120" height="120" margin=0/>
+ <img src="Images/reacher.gif" width="120" height="120" margin=0/>
+ <img src="Images/reservoir.gif" width="120" height="120" margin=0/>
+ </p>
+
+ > [!WARNING]
+ > Starting in version 1.0 (major release), the ``weight`` parameter in the config file was removed
+ > and moved to the individual logic components, each of which now has its own weight parameter.
+ > Furthermore, the tuning module has been redesigned from the ground up, and supports tuning arbitrary hyper-parameters via config templates!
+ > Finally, the terrible visualizer for the planner was removed and replaced with an interactive real-time dashboard (similar to tensorboard, but custom designed for the planner)!

  > [!NOTE]
- > While Jax planners can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
+ > While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
  > If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).

  ## Contents

  - [Installation](#installation)
  - [Running from the Command Line](#running-from-the-command-line)
- - [Running from within Python](#running-from-within-python)
+ - [Running from Another Python Application](#running-from-another-python-application)
  - [Configuring the Planner](#configuring-the-planner)
+ - [JaxPlan Dashboard](#jaxplan-dashboard)
+ - [Tuning the Planner](#tuning-the-planner)
  - [Simulation](#simulation)
- - [Manual Gradient Calculation](#manual-gradient-calculation)
- - [Citing pyRDDLGym-jax](#citing-pyrddlgym-jax)
+ - [Citing JaxPlan](#citing-jaxplan)

  ## Installation

- To use the compiler or planner without the automated hyper-parameter tuning, you will need the following packages installed:
- - ``pyRDDLGym>=2.0``
- - ``tqdm>=4.66``
- - ``jax>=0.4.12``
- - ``optax>=0.1.9``
- - ``dm-haiku>=0.0.10``
- - ``tensorflow-probability>=0.21.0``
+ To install the bare-bones version of JaxPlan with **minimum installation requirements**:
+
+ ```shell
+ pip install pyRDDLGym-jax
+ ```
+
+ To install JaxPlan with the **automatic hyper-parameter tuning** and rddlrepository:

- Additionally, if you wish to run the examples, you need ``rddlrepository>=2``.
- To run the automated tuning optimization, you will also need ``bayesian-optimization>=1.4.3``.
+ ```shell
+ pip install pyRDDLGym-jax[extra]
+ ```

- You can install this package, together with all of its requirements, via pip:
+ (Since version 1.0) To install JaxPlan with the **visualization dashboard**:

  ```shell
- pip install rddlrepository pyRDDLGym-jax
+ pip install pyRDDLGym-jax[dashboard]
+ ```
+
+ (Since version 1.0) To install JaxPlan with **all options**:
+
+ ```shell
+ pip install pyRDDLGym-jax[extra,dashboard]
  ```

  ## Running from the Command Line

- A basic run script is provided to run the Jax Planner on any domain in ``rddlrepository``, and can be launched in the command line from the install directory of pyRDDLGym-jax:
+ A basic run script is provided to run JaxPlan on any domain in ``rddlrepository`` from the install directory of pyRDDLGym-jax:

  ```shell
  python -m pyRDDLGym_jax.examples.run_plan <domain> <instance> <method> <episodes>
@@ -84,35 +113,15 @@ The ``method`` parameter supports three possible modes:
  - ``drp`` is the deep reactive policy network described [in this paper](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
  - ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step.

- A basic run script is also provided to run the automatic hyper-parameter tuning:
-
- ```shell
- python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials> <iters> <workers>
- ```
-
- where:
- - ``domain`` is the domain identifier as specified in rddlrepository (i.e. Wildfire_MDP_ippc2014)
- - ``instance`` is the instance identifier (i.e. 1, 2, ... 10)
- - ``method`` is the planning method to use (i.e. drp, slp, replan)
- - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
- - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
- - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``.
-
- For example, the following will train the Jax Planner on the Quadcopter domain with 4 drones:
+ For example, the following will train JaxPlan on the Quadcopter domain with 4 drones:

  ```shell
  python -m pyRDDLGym_jax.examples.run_plan Quadcopter 1 slp
  ```

- After several minutes of optimization, you should get a visualization as follows:
-
- <p align="center">
- <img src="Images/quadcopter.gif" width="400" height="400" margin=1/>
- </p>
-
- ## Running from within Python
+ ## Running from Another Python Application

- To run the Jax planner from within a Python application, refer to the following example:
+ To run JaxPlan from within a Python application, refer to the following example:

  ```python
  import pyRDDLGym
@@ -142,17 +151,15 @@ The basic structure of a configuration file is provided below for a straight-lin
  ```ini
  [Model]
  logic='FuzzyLogic'
- logic_kwargs={'weight': 20}
- tnorm='ProductTNorm'
- tnorm_kwargs={}
+ comparison_kwargs={'weight': 20}
+ rounding_kwargs={'weight': 20}
+ control_kwargs={'weight': 20}

  [Optimizer]
  method='JaxStraightLinePlan'
  method_kwargs={}
  optimizer='rmsprop'
  optimizer_kwargs={'learning_rate': 0.001}
- batch_size_train=1
- batch_size_test=1

  [Training]
  key=42
@@ -177,7 +184,7 @@ method_kwargs={'topology': [128, 64], 'activation': 'tanh'}
  ```

  The configuration file must then be passed to the planner during initialization.
- For example, the [previous script here](#running-from-within-python) can be modified to set parameters from a config file:
+ For example, the [previous script here](#running-from-another-python-application) can be modified to set parameters from a config file:

  ```python
  from pyRDDLGym_jax.core.planner import load_config
@@ -191,6 +198,101 @@ controller = JaxOfflineController(planner, **train_args)
  ...
  ```

+ ### JaxPlan Dashboard
+
+ Since version 1.0, JaxPlan provides an optional dashboard that tracks planner performance across multiple runs,
+ visualizes the policy or model, and offers other useful debugging features.
+
+ <p align="middle">
+ <img src="Images/dashboard.png" width="480" height="248" margin=0/>
+ </p>
+
+ To run the dashboard, add the following entry to your config file:
+
+ ```ini
+ ...
+ [Training]
+ dashboard=True
+ ...
+ ```
+
+ More documentation about this and other new features will be coming soon.
+
+ ### Tuning the Planner
+
+ It is easy to tune the planner's hyper-parameters efficiently and automatically using Bayesian optimization.
+ To do this, first create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
+
+ ```ini
+ [Model]
+ logic='FuzzyLogic'
+ comparison_kwargs={'weight': TUNABLE_WEIGHT}
+ rounding_kwargs={'weight': TUNABLE_WEIGHT}
+ control_kwargs={'weight': TUNABLE_WEIGHT}
+
+ [Optimizer]
+ method='JaxStraightLinePlan'
+ method_kwargs={}
+ optimizer='rmsprop'
+ optimizer_kwargs={'learning_rate': TUNABLE_LEARNING_RATE}
+
+ [Training]
+ train_seconds=30
+ print_summary=False
+ print_progress=False
+ train_on_reset=True
+ ```
+
+ would allow tuning the sharpness of the model relaxations and the learning rate of the optimizer.
+
+ Next, you must link the patterns in the config with concrete hyper-parameter ranges the tuner will understand:
+
+ ```python
+ import pyRDDLGym
+ from pyRDDLGym_jax.core.tuning import JaxParameterTuning, Hyperparameter
+
+ # set up the environment
+ env = pyRDDLGym.make(domain, instance, vectorized=True)
+
+ # load the config file template with planner settings
+ with open('path/to/config.cfg', 'r') as file:
+     config_template = file.read()
+
+ # map parameters in the config that will be tuned
+ def power_10(x):
+     return 10.0 ** x
+
+ hyperparams = [
+     Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),  # tune weight from 10^-1 ... 10^5
+     Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10),  # tune lr from 10^-5 ... 10^1
+ ]
+
+ # build the tuner and tune
+ tuning = JaxParameterTuning(env=env,
+                             config_template=config_template,
+                             hyperparams=hyperparams,
+                             online=False,
+                             eval_trials=trials,
+                             num_workers=workers,
+                             gp_iters=iters)
+ tuning.tune(key=42, log_file='path/to/log.csv')
+ ```
+
+ A basic run script is provided to run the automatic hyper-parameter tuning for the most sensitive parameters of JaxPlan:
+
+ ```shell
+ python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials> <iters> <workers>
+ ```
+
+ where:
+ - ``domain`` is the domain identifier as specified in rddlrepository
+ - ``instance`` is the instance identifier
+ - ``method`` is the planning method to use (i.e. drp, slp, replan)
+ - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
+ - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
+ - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``.
+
+
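
For illustration only (the instance number and budget values below are hypothetical, not taken from the package), a run following the argument order above that tunes the straight-line planner on the Wildfire_MDP_ippc2014 domain with 5 evaluation trials, 20 Bayesian optimization iterations, and 4 parallel workers would look like:

```shell
python -m pyRDDLGym_jax.examples.run_tune Wildfire_MDP_ippc2014 1 slp 5 20 4
```
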
  ## Simulation

  The JAX compiler can be used as a backend for simulating and evaluating RDDL environments:
@@ -204,47 +306,17 @@ from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator
  env = pyRDDLGym.make("domain", "instance", backend=JaxRDDLSimulator)

  # evaluate the random policy
- agent = RandomAgent(action_space=env.action_space,
- num_actions=env.max_allowed_actions)
+ agent = RandomAgent(action_space=env.action_space, num_actions=env.max_allowed_actions)
  agent.evaluate(env, verbose=True, render=True)
  ```

  For some domains, the JAX backend could perform better than the numpy-based one, due to various compiler optimizations.
  In any event, the simulation results using the JAX backend should (almost) always match the numpy backend.
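
To make that comparison concrete, here is a minimal sketch (not from the package README) that runs the same random policy under both backends; the ``RandomAgent`` import path, the ``episodes`` argument, and the assumption that ``evaluate`` returns printable summary statistics may vary across pyRDDLGym versions:

```python
import pyRDDLGym
from pyRDDLGym.core.policy import RandomAgent
from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator

# the same domain and instance, once with the default numpy backend and once with JAX
env_np = pyRDDLGym.make("domain", "instance")
env_jax = pyRDDLGym.make("domain", "instance", backend=JaxRDDLSimulator)

# a random policy shared by both environments (their action spaces are identical)
agent = RandomAgent(action_space=env_np.action_space,
                    num_actions=env_np.max_allowed_actions)

# stochastic domains draw their own random numbers in each backend,
# so the two summaries are only expected to agree approximately
stats_np = agent.evaluate(env_np, episodes=10, verbose=False)
stats_jax = agent.evaluate(env_jax, episodes=10, verbose=False)
print(stats_np)   # summary of returns under the numpy backend
print(stats_jax)  # the same summary under the JAX backend
```
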
214
315
 
215
- ## Manual Gradient Calculation
216
-
217
- For custom applications, it is desirable to compute gradients of the model that can be optimized downstream.
218
- Fortunately, we provide a very convenient function for compiling the transition/step function ``P(s, a, s')`` of the environment into JAX.
219
316
 
220
- ```python
221
- import pyRDDLGym
222
- from pyRDDLGym_jax.core.planner import JaxRDDLCompilerWithGrad
317
+ ## Citing JaxPlan
223
318
 
224
- # set up the environment
225
- env = pyRDDLGym.make("domain", "instance", vectorized=True)
226
-
227
- # create the step function
228
- compiled = JaxRDDLCompilerWithGrad(rddl=env.model)
229
- compiled.compile()
230
- step_fn = compiled.compile_transition()
231
- ```
232
-
233
- This will return a JAX compiled (pure) function requiring the following inputs:
234
- - ``key`` is the ``jax.random.PRNGKey`` key for reproducible randomness
235
- - ``actions`` is the dictionary of action fluent tensors
236
- - ``subs`` is the dictionary of state-fluent and non-fluent tensors
237
- - ``model_params`` are the parameters of the differentiable relaxations, such as ``weight``
238
-
239
- The function returns a dictionary containing a variety of variables, such as updated pvariables including next-state fluents (``pvar``), reward obtained (``reward``), error codes (``error``).
240
- It is thus possible to apply any JAX transformation to the output of the function, such as computing gradient using ``jax.grad()`` or batched simulation using ``jax.vmap()``.
241
-
242
- Compilation of entire rollouts is also possible by calling the ``compile_rollouts`` function.
243
- An [example is provided to illustrate how you can define your own policy class and compute the return gradient manually](https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/pyRDDLGym_jax/examples/run_gradient.py).
244
-
245
- ## Citing pyRDDLGym-jax
246
-
247
- The [following citation](https://ojs.aaai.org/index.php/ICAPS/article/view/31480) describes the main ideas of the framework. Please cite it if you found it useful:
319
+ The [following citation](https://ojs.aaai.org/index.php/ICAPS/article/view/31480) describes the main ideas of JaxPlan. Please cite it if you found it useful:
248
320
 
249
321
  ```
250
322
  @inproceedings{gimelfarb2024jaxplan,
@@ -256,21 +328,8 @@ The [following citation](https://ojs.aaai.org/index.php/ICAPS/article/view/31480
  }
  ```

- The utility optimization is discussed in [this paper](https://ojs.aaai.org/index.php/AAAI/article/view/21226):
-
- ```
- @inproceedings{patton2022distributional,
- title={A distributional framework for risk-sensitive end-to-end planning in continuous mdps},
- author={Patton, Noah and Jeong, Jihwan and Gimelfarb, Mike and Sanner, Scott},
- booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
- volume={36},
- number={9},
- pages={9894--9901},
- year={2022}
- }
- ```
-
  Some of the implementation details derive from the following literature, which you may wish to also cite in your research papers:
+ - [A Distributional Framework for Risk-Sensitive End-to-End Planning in Continuous MDPs](https://ojs.aaai.org/index.php/AAAI/article/view/21226)
  - [Deep reactive policies for planning in stochastic nonlinear domains, AAAI 2019](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
  - [Scalable planning with tensorflow for hybrid nonlinear domains, NeurIPS 2017](https://proceedings.neurips.cc/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)