pyRDDLGym-jax 0.5.tar.gz → 1.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/PKG-INFO +153 -96
  2. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/README.md +149 -95
  3. pyrddlgym_jax-1.0/pyRDDLGym_jax/__init__.py +1 -0
  4. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/core/compiler.py +463 -592
  5. pyrddlgym_jax-1.0/pyRDDLGym_jax/core/logic.py +1083 -0
  6. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/core/planner.py +329 -463
  7. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/core/simulator.py +7 -5
  8. pyrddlgym_jax-1.0/pyRDDLGym_jax/core/tuning.py +511 -0
  9. pyrddlgym_jax-1.0/pyRDDLGym_jax/core/visualization.py +1463 -0
  10. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
  11. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +4 -5
  12. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
  13. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
  14. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
  15. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
  16. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
  17. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
  18. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +3 -3
  19. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
  20. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
  21. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
  22. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
  23. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +3 -3
  24. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
  25. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
  26. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
  27. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +3 -3
  28. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
  29. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
  30. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/default_replan.cfg +3 -3
  31. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
  32. pyrddlgym_jax-1.0/pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
  33. pyrddlgym_jax-1.0/pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
  34. pyrddlgym_jax-1.0/pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
  35. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/run_plan.py +4 -1
  36. pyrddlgym_jax-1.0/pyRDDLGym_jax/examples/run_tune.py +91 -0
  37. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax.egg-info/PKG-INFO +153 -96
  38. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax.egg-info/SOURCES.txt +5 -4
  39. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax.egg-info/requires.txt +4 -0
  40. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/setup.py +3 -2
  41. pyrddlgym_jax-0.5/pyRDDLGym_jax/__init__.py +0 -1
  42. pyrddlgym_jax-0.5/pyRDDLGym_jax/core/logic.py +0 -843
  43. pyrddlgym_jax-0.5/pyRDDLGym_jax/core/tuning.py +0 -700
  44. pyrddlgym_jax-0.5/pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
  45. pyrddlgym_jax-0.5/pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
  46. pyrddlgym_jax-0.5/pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
  47. pyrddlgym_jax-0.5/pyRDDLGym_jax/examples/run_tune.py +0 -78
  48. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/LICENSE +0 -0
  49. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/core/__init__.py +0 -0
  50. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/__init__.py +0 -0
  51. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
  52. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/run_gradient.py +0 -0
  53. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/run_gym.py +0 -0
  54. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax/examples/run_scipy.py +0 -0
  55. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax.egg-info/dependency_links.txt +0 -0
  56. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/pyRDDLGym_jax.egg-info/top_level.txt +0 -0
  57. {pyrddlgym_jax-0.5 → pyrddlgym_jax-1.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: pyRDDLGym-jax
- Version: 0.5
+ Version: 1.0
  Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -25,51 +25,78 @@ Requires-Dist: tensorflow-probability>=0.21.0
  Provides-Extra: extra
  Requires-Dist: bayesian-optimization>=2.0.0; extra == "extra"
  Requires-Dist: rddlrepository>=2.0; extra == "extra"
+ Provides-Extra: dashboard
+ Requires-Dist: dash>=2.18.0; extra == "dashboard"
+ Requires-Dist: dash-bootstrap-components>=1.6.0; extra == "dashboard"

  # pyRDDLGym-jax

- Author: [Mike Gimelfarb](https://mike-gimelfarb.github.io)
+ **pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.** It provides:

- This directory provides:
- 1. automated translation and compilation of RDDL description files into [JAX](https://github.com/google/jax), converting any RDDL domain to a differentiable simulator!
- 2. powerful, fast and scalable gradient-based planning algorithms, with extendible and flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and much more!
+ 1. automatic translation of any RDDL description file into a differentiable simulator in JAX
+ 2. flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and Bayesian hyper-parameter tuning.
+
+ Some demos of problems solved by JaxPlan:
+
+ <p align="middle">
+ <img src="Images/intruders.gif" width="120" height="120" margin=0/>
+ <img src="Images/marsrover.gif" width="120" height="120" margin=0/>
+ <img src="Images/pong.gif" width="120" height="120" margin=0/>
+ <img src="Images/quadcopter.gif" width="120" height="120" margin=0/>
+ <img src="Images/reacher.gif" width="120" height="120" margin=0/>
+ <img src="Images/reservoir.gif" width="120" height="120" margin=0/>
+ </p>
+
+ > [!WARNING]
+ > Starting in version 1.0 (major release), the ``weight`` parameter was removed from the config file and moved to the individual logic components, each of which now has its own weight parameter.
+ > Furthermore, the tuning module has been redesigned from the ground up and now supports tuning arbitrary hyper-parameters via config templates!
+ > Finally, the terrible visualizer for the planner was removed and replaced with an interactive real-time dashboard (similar to tensorboard, but custom-designed for the planner)!

  > [!NOTE]
- > While Jax planners can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
+ > While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
  > If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).

  ## Contents

  - [Installation](#installation)
  - [Running from the Command Line](#running-from-the-command-line)
- - [Running from within Python](#running-from-within-python)
+ - [Running from Another Python Application](#running-from-another-python-application)
  - [Configuring the Planner](#configuring-the-planner)
+ - [JaxPlan Dashboard](#jaxplan-dashboard)
+ - [Tuning the Planner](#tuning-the-planner)
  - [Simulation](#simulation)
- - [Manual Gradient Calculation](#manual-gradient-calculation)
- - [Citing pyRDDLGym-jax](#citing-pyrddlgym-jax)
+ - [Citing JaxPlan](#citing-jaxplan)

  ## Installation

- To use the compiler or planner without the automated hyper-parameter tuning, you will need the following packages installed:
- - ``pyRDDLGym>=2.0``
- - ``tqdm>=4.66``
- - ``jax>=0.4.12``
- - ``optax>=0.1.9``
- - ``dm-haiku>=0.0.10``
- - ``tensorflow-probability>=0.21.0``
+ To install the bare-bones version of JaxPlan with **minimum installation requirements**:

- Additionally, if you wish to run the examples, you need ``rddlrepository>=2``.
- To run the automated tuning optimization, you will also need ``bayesian-optimization>=2.0.0``.
+ ```shell
+ pip install pyRDDLGym-jax
+ ```

- You can install pyRDDLGym-jax with all requirements using pip:
+ To install JaxPlan with the **automatic hyper-parameter tuning** and rddlrepository:

  ```shell
  pip install pyRDDLGym-jax[extra]
  ```

+ (Since version 1.0) To install JaxPlan with the **visualization dashboard**:
+
+ ```shell
+ pip install pyRDDLGym-jax[dashboard]
+ ```
+
+ (Since version 1.0) To install JaxPlan with **all options**:
+
+ ```shell
+ pip install pyRDDLGym-jax[extra,dashboard]
+ ```
+
  ## Running from the Command Line

- A basic run script is provided to run the Jax Planner on any domain in ``rddlrepository`` from the install directory of pyRDDLGym-jax:
+ A basic run script is provided to run JaxPlan on any domain in ``rddlrepository`` from the install directory of pyRDDLGym-jax:

  ```shell
  python -m pyRDDLGym_jax.examples.run_plan <domain> <instance> <method> <episodes>
@@ -86,35 +113,15 @@ The ``method`` parameter supports three possible modes:
  - ``drp`` is the deep reactive policy network described [in this paper](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
  - ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step.

- A basic run script is also provided to run the automatic hyper-parameter tuning:
-
- ```shell
- python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials> <iters> <workers>
- ```
-
- where:
- - ``domain`` is the domain identifier as specified in rddlrepository
- - ``instance`` is the instance identifier
- - ``method`` is the planning method to use (i.e. drp, slp, replan)
- - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
- - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
- - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``.
-
- For example, the following will train the Jax Planner on the Quadcopter domain with 4 drones:
+ For example, the following will train JaxPlan on the Quadcopter domain with 4 drones:

  ```shell
  python -m pyRDDLGym_jax.examples.run_plan Quadcopter 1 slp
  ```

- After several minutes of optimization, you should get a visualization as follows:
-
- <p align="center">
- <img src="Images/quadcopter.gif" width="400" height="400" margin=1/>
- </p>
-
- ## Running from within Python
+ ## Running from Another Python Application

- To run the Jax planner from within a Python application, refer to the following example:
+ To run JaxPlan from within a Python application, refer to the following example:

  ```python
  import pyRDDLGym
@@ -144,17 +151,15 @@ The basic structure of a configuration file is provided below for a straight-lin
  ```ini
  [Model]
  logic='FuzzyLogic'
- logic_kwargs={'weight': 20}
- tnorm='ProductTNorm'
- tnorm_kwargs={}
+ comparison_kwargs={'weight': 20}
+ rounding_kwargs={'weight': 20}
+ control_kwargs={'weight': 20}

  [Optimizer]
  method='JaxStraightLinePlan'
  method_kwargs={}
  optimizer='rmsprop'
  optimizer_kwargs={'learning_rate': 0.001}
- batch_size_train=1
- batch_size_test=1

  [Training]
  key=42
@@ -179,7 +184,7 @@ method_kwargs={'topology': [128, 64], 'activation': 'tanh'}
  ```

  The configuration file must then be passed to the planner during initialization.
- For example, the [previous script here](#running-from-within-python) can be modified to set parameters from a config file:
+ For example, the [previous script here](#running-from-another-python-application) can be modified to set parameters from a config file:

  ```python
  from pyRDDLGym_jax.core.planner import load_config
@@ -193,6 +198,101 @@ controller = JaxOfflineController(planner, **train_args)
  ...
  ```

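The Python usage example above appears only in fragments in this diff. Below is a minimal sketch of how the pieces typically fit together; the three-value return of ``load_config`` and the ``JaxBackpropPlanner`` constructor arguments are assumptions inferred from the fragments shown here, not verbatim package code.

```python
import pyRDDLGym
from pyRDDLGym_jax.core.planner import (
    JaxBackpropPlanner, JaxOfflineController, load_config)

# create the environment (vectorized mode is expected by the planner)
env = pyRDDLGym.make('Quadcopter', '1', vectorized=True)

# load planner and training settings from a config file
# (assumed 3-tuple return: planner args, plan args, training args)
planner_args, _, train_args = load_config('path/to/config.cfg')

# build the planner and wrap it in an offline (train-once) controller
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
controller = JaxOfflineController(planner, **train_args)

# train, then evaluate the resulting plan in the environment
controller.evaluate(env, episodes=1, verbose=True, render=True)
env.close()
```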
+ ### JaxPlan Dashboard
+
+ Since version 1.0, JaxPlan provides an optional dashboard for tracking planner performance across multiple runs,
+ visualizing the policy or model, and other useful debugging features.
+
+ <p align="middle">
+ <img src="Images/dashboard.png" width="480" height="248" margin=0/>
+ </p>
+
+ To run the dashboard, add the following entry to your config file:
+
+ ```ini
+ ...
+ [Training]
+ dashboard=True
+ ...
+ ```
+
+ More documentation about this and other new features will be coming soon.
+
+ ### Tuning the Planner
+
+ It is easy to tune the planner's hyper-parameters efficiently and automatically using Bayesian optimization.
+ To do this, first create a config file template with patterns replacing the concrete parameter values that you want to tune, e.g.:
+
+ ```ini
+ [Model]
+ logic='FuzzyLogic'
+ comparison_kwargs={'weight': TUNABLE_WEIGHT}
+ rounding_kwargs={'weight': TUNABLE_WEIGHT}
+ control_kwargs={'weight': TUNABLE_WEIGHT}
+
+ [Optimizer]
+ method='JaxStraightLinePlan'
+ method_kwargs={}
+ optimizer='rmsprop'
+ optimizer_kwargs={'learning_rate': TUNABLE_LEARNING_RATE}
+
+ [Training]
+ train_seconds=30
+ print_summary=False
+ print_progress=False
+ train_on_reset=True
+ ```
+
+ This template would allow tuning the sharpness of the model relaxations and the learning rate of the optimizer.
+
+ Next, you must link the patterns in the config with concrete hyper-parameter ranges the tuner will understand:
+
+ ```python
+ import pyRDDLGym
+ from pyRDDLGym_jax.core.tuning import JaxParameterTuning, Hyperparameter
+
+ # set up the environment
+ env = pyRDDLGym.make(domain, instance, vectorized=True)
+
+ # load the config file template with planner settings
+ with open('path/to/config.cfg', 'r') as file:
+     config_template = file.read()
+
+ # map parameters in the config that will be tuned
+ def power_10(x):
+     return 10.0 ** x
+
+ hyperparams = [
+     Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),         # tune weight from 10^-1 ... 10^5
+     Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10),  # tune lr from 10^-5 ... 10^1
+ ]
+
+ # build the tuner and tune
+ tuning = JaxParameterTuning(env=env,
+                             config_template=config_template,
+                             hyperparams=hyperparams,
+                             online=False,
+                             eval_trials=trials,
+                             num_workers=workers,
+                             gp_iters=iters)
+ tuning.tune(key=42, log_file='path/to/log.csv')
+ ```
+
+ A basic run script is provided to run the automatic hyper-parameter tuning for the most sensitive parameters of JaxPlan:
+
+ ```shell
+ python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials> <iters> <workers>
+ ```
+
+ where:
+ - ``domain`` is the domain identifier as specified in rddlrepository
+ - ``instance`` is the instance identifier
+ - ``method`` is the planning method to use (i.e. drp, slp, replan)
+ - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
+ - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
+ - ``workers`` is the (optional) number of parallel evaluations to perform at each iteration, i.e. the total evaluations = ``iters * workers`` (see the example invocation below).
+
+
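For concreteness (this invocation is not part of the diff), tuning the straight-line planner on the Quadcopter instance used earlier, with 5 evaluation trials per setting, 20 Bayesian-optimization iterations, and 4 parallel workers, might look like:

```shell
python -m pyRDDLGym_jax.examples.run_tune Quadcopter 1 slp 5 20 4
```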
  ## Simulation

  The JAX compiler can be used as a backend for simulating and evaluating RDDL environments:
@@ -206,47 +306,17 @@ from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator
  env = pyRDDLGym.make("domain", "instance", backend=JaxRDDLSimulator)

  # evaluate the random policy
- agent = RandomAgent(action_space=env.action_space,
-                     num_actions=env.max_allowed_actions)
+ agent = RandomAgent(action_space=env.action_space, num_actions=env.max_allowed_actions)
  agent.evaluate(env, verbose=True, render=True)
  ```

  For some domains, the JAX backend could perform better than the numpy-based one, due to various compiler optimizations.
  In any event, the simulation results using the JAX backend should (almost) always match the numpy backend.

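To make the claim about matching backends concrete, here is a small sketch (not from this diff) that runs the same random policy under both the default numpy backend and the JAX backend; the assumption that ``evaluate`` returns a dictionary of summary return statistics is mine, not stated in the diff.

```python
import pyRDDLGym
from pyRDDLGym.core.policy import RandomAgent
from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator

# same domain and instance, two different simulation backends
env_numpy = pyRDDLGym.make("domain", "instance")
env_jax = pyRDDLGym.make("domain", "instance", backend=JaxRDDLSimulator)

for name, env in [("numpy", env_numpy), ("jax", env_jax)]:
    agent = RandomAgent(action_space=env.action_space,
                        num_actions=env.max_allowed_actions)
    # evaluate() is assumed to return summary statistics of the episode returns
    stats = agent.evaluate(env, episodes=5, verbose=False, render=False)
    print(name, stats)
    env.close()
```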
- ## Manual Gradient Calculation
-
- For custom applications, it is desirable to compute gradients of the model that can be optimized downstream.
- Fortunately, we provide a very convenient function for compiling the transition/step function ``P(s, a, s')`` of the environment into JAX.
-
- ```python
- import pyRDDLGym
- from pyRDDLGym_jax.core.planner import JaxRDDLCompilerWithGrad
-
- # set up the environment
- env = pyRDDLGym.make("domain", "instance", vectorized=True)
-
- # create the step function
- compiled = JaxRDDLCompilerWithGrad(rddl=env.model)
- compiled.compile()
- step_fn = compiled.compile_transition()
- ```
-
- This will return a JAX compiled (pure) function requiring the following inputs:
- - ``key`` is the ``jax.random.PRNGKey`` key for reproducible randomness
- - ``actions`` is the dictionary of action fluent tensors
- - ``subs`` is the dictionary of state-fluent and non-fluent tensors
- - ``model_params`` are the parameters of the differentiable relaxations, such as ``weight``
-
- The function returns a dictionary containing a variety of variables, such as updated pvariables including next-state fluents (``pvar``), reward obtained (``reward``), error codes (``error``).
- It is thus possible to apply any JAX transformation to the output of the function, such as computing gradient using ``jax.grad()`` or batched simulation using ``jax.vmap()``.

- Compilation of entire rollouts is also possible by calling the ``compile_rollouts`` function.
- An [example is provided to illustrate how you can define your own policy class and compute the return gradient manually](https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/pyRDDLGym_jax/examples/run_gradient.py).
+ ## Citing JaxPlan

- ## Citing pyRDDLGym-jax
-
- The [following citation](https://ojs.aaai.org/index.php/ICAPS/article/view/31480) describes the main ideas of the framework. Please cite it if you found it useful:
+ The [following citation](https://ojs.aaai.org/index.php/ICAPS/article/view/31480) describes the main ideas of JaxPlan. Please cite it if you found it useful:

  ```
  @inproceedings{gimelfarb2024jaxplan,
@@ -258,21 +328,8 @@ The [following citation](https://ojs.aaai.org/index.php/ICAPS/article/view/31480
  }
  ```

- The utility optimization is discussed in [this paper](https://ojs.aaai.org/index.php/AAAI/article/view/21226):
-
- ```
- @inproceedings{patton2022distributional,
-   title={A distributional framework for risk-sensitive end-to-end planning in continuous mdps},
-   author={Patton, Noah and Jeong, Jihwan and Gimelfarb, Mike and Sanner, Scott},
-   booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
-   volume={36},
-   number={9},
-   pages={9894--9901},
-   year={2022}
- }
- ```
-
  Some of the implementation details derive from the following literature, which you may wish to also cite in your research papers:
+ - [A Distributional Framework for Risk-Sensitive End-to-End Planning in Continuous MDPs](https://ojs.aaai.org/index.php/AAAI/article/view/21226)
  - [Deep reactive policies for planning in stochastic nonlinear domains, AAAI 2019](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
  - [Scalable planning with tensorflow for hybrid nonlinear domains, NeurIPS 2017](https://proceedings.neurips.cc/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)