pyRDDLGym-jax-0.1.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym-jax-0.1/LICENSE +21 -0
- pyRDDLGym-jax-0.1/PKG-INFO +25 -0
- pyRDDLGym-jax-0.1/README.md +250 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/__init__.py +0 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/core/__init__.py +0 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/core/compiler.py +1781 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/core/logic.py +572 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/core/planner.py +1675 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/core/simulator.py +197 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/core/tuning.py +691 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/__init__.py +0 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_drp.cfg +19 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_replan.cfg +19 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_slp.cfg +19 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/HVAC_drp.cfg +18 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/HVAC_slp.cfg +18 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/MarsRover_drp.cfg +18 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/MarsRover_slp.cfg +20 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/MountainCar_slp.cfg +19 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/Pendulum_slp.cfg +18 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/Pong_slp.cfg +18 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/PowerGen_drp.cfg +18 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/PowerGen_replan.cfg +19 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/PowerGen_slp.cfg +18 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +18 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +19 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +18 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/SupplyChain_slp.cfg +18 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/Traffic_slp.cfg +20 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +18 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +18 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +20 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +19 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/run_gradient.py +102 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/run_gym.py +48 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/run_plan.py +61 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/run_tune.py +75 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax.egg-info/PKG-INFO +25 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax.egg-info/SOURCES.txt +43 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax.egg-info/dependency_links.txt +1 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax.egg-info/requires.txt +8 -0
- pyRDDLGym-jax-0.1/pyRDDLGym_jax.egg-info/top_level.txt +1 -0
- pyRDDLGym-jax-0.1/setup.cfg +4 -0
- pyRDDLGym-jax-0.1/setup.py +49 -0
pyRDDLGym-jax-0.1/LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 pyrddlgym-project

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

pyRDDLGym-jax-0.1/PKG-INFO
@@ -0,0 +1,25 @@
Metadata-Version: 2.1
Name: pyRDDLGym-jax
Version: 0.1
Summary: pyRDDLGym-jax: JAX compilation of RDDL description files, and a differentiable planner in JAX.
Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
Author-email: mike.gimelfarb@mail.utoronto.ca, ataitler@gmail.com, ssanner@mie.utoronto.ca
License: MIT License
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Natural Language :: English
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.8
License-File: LICENSE
Requires-Dist: pyRDDLGym>=2.0
Requires-Dist: tqdm>=4.66
Requires-Dist: bayesian-optimization>=1.4.3
Requires-Dist: jax>=0.4.12
Requires-Dist: optax>=0.1.9
Requires-Dist: dm-haiku>=0.0.10
Requires-Dist: tensorflow>=2.13.0
Requires-Dist: tensorflow-probability>=0.21.0

pyRDDLGym-jax-0.1/README.md
@@ -0,0 +1,250 @@
# pyRDDLGym-jax

Author: [Mike Gimelfarb](https://mike-gimelfarb.github.io)

This directory provides:
1. automated translation and compilation of RDDL description files into the [JAX](https://github.com/google/jax) auto-diff library, which allows any RDDL domain to be converted into a differentiable simulator!
2. powerful, fast, and very scalable gradient-based planning algorithms, with extensible and flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and much more!

> [!NOTE]
> While the JAX planner can support some discrete state/action problems through model relaxations, it can perform poorly on some discrete problems (though there is an ongoing effort to remedy this!).
> If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).

## Contents

- [Installation](#installation)
- [Running the Basic Examples](#running-the-basic-examples)
- [Running from the Python API](#running-from-the-python-api)
- [Writing a Configuration File for a Custom Domain](#writing-a-configuration-file-for-a-custom-domain)
- [Changing the pyRDDLGym Simulation Backend to JAX](#changing-the-pyrddlgym-simulation-backend-to-jax)
- [Computing the Model Gradients Manually](#computing-the-model-gradients-manually)
- [Citing pyRDDLGym-jax](#citing-pyrddlgym-jax)

## Installation

To use the compiler or planner without the automated hyper-parameter tuning, you will need the following packages installed:
- ``pyRDDLGym>=2.0``
- ``tqdm>=4.66``
- ``jax>=0.4.12``
- ``optax>=0.1.9``
- ``dm-haiku>=0.0.10``
- ``tensorflow>=2.13.0``
- ``tensorflow-probability>=0.21.0``

Additionally, to run the examples you will need ``rddlrepository>=2``, and to run the automated tuning optimization you will also need ``bayesian-optimization>=1.4.3``.

You can install this package, together with all of its requirements, as follows (assuming Anaconda):

```shell
# Create a new conda environment
conda create -n jaxplan python=3.11
conda activate jaxplan
conda install pip git

# Manually install pyRDDLGym and rddlrepository
pip install git+https://github.com/pyrddlgym-project/pyRDDLGym
pip install git+https://github.com/pyrddlgym-project/rddlrepository

# Install pyRDDLGym-jax
pip install git+https://github.com/pyrddlgym-project/pyRDDLGym-jax
```

A pip installer will be coming soon.

## Running the Basic Examples

A basic run script is provided to run the JAX planner on any domain in ``rddlrepository``, provided a config file is available (currently, only a limited subset of configs are provided as examples).
The example can be run as follows in a standard shell, from the install directory of pyRDDLGym-jax:

```shell
python -m pyRDDLGym_jax.examples.run_plan <domain> <instance> <method> <episodes>
```

where:
- ``domain`` is the domain identifier as specified in rddlrepository (e.g. Wildfire_MDP_ippc2014), or a path pointing to a valid ``domain.rddl`` file
- ``instance`` is the instance identifier (e.g. 1, 2, ... 10), or a path pointing to a valid ``instance.rddl`` file
- ``method`` is the planning method to use (i.e. drp, slp, replan)
- ``episodes`` is the (optional) number of episodes over which to evaluate the learned policy

The ``method`` parameter warrants further explanation. Currently, we support three possible modes:
- ``slp`` is the basic straight-line planner described [in this paper](https://proceedings.neurips.cc/paper_files/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)
- ``drp`` is the deep reactive policy network described [in this paper](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
- ``replan`` is the same as ``slp``, except the plan is recalculated at every decision time step.

A basic run script is also provided to run the automated hyper-parameter tuning. The structure of this script is similar to the one above:

```shell
python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials> <iters> <workers>
```

where:
- ``domain`` is the domain identifier as specified in rddlrepository (e.g. Wildfire_MDP_ippc2014)
- ``instance`` is the instance identifier (e.g. 1, 2, ... 10)
- ``method`` is the planning method to use (i.e. drp, slp, replan)
- ``trials`` is the (optional) number of trials/episodes to average when evaluating each hyper-parameter setting
- ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
- ``workers`` is the (optional) number of parallel evaluations to perform at each iteration, so the total number of evaluations is ``iters * workers``

For example, copying and pasting the following will train the JAX planner on the Quadcopter domain with 4 drones:

```shell
python -m pyRDDLGym_jax.examples.run_plan Quadcopter 1 slp
```

After several minutes of optimization, you should get a visualization like the following:

<p align="center">
<img src="Images/quadcopter.gif" width="400" height="400" margin=1/>
</p>

## Running from the Python API

You can write a simple script to instantiate and run the planner yourself, if you wish:

```python
import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOfflineController

# set up the environment (note the vectorized option must be True)
env = pyRDDLGym.make("domain", "instance", vectorized=True)

# create the planning algorithm
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
controller = JaxOfflineController(planner, **train_args)

# evaluate the planner
controller.evaluate(env, episodes=1, verbose=True, render=True)

env.close()
```

Here, we have used the straight-line (offline) controller, although you can configure the combination of planner and policy representation, as sketched below.
All controllers are instances of pyRDDLGym's ``BaseAgent`` class, so they support the ``evaluate()`` function to streamline interaction with the environment.
The ``**planner_args`` and ``**train_args`` are keyword arguments passed during initialization, but we strongly recommend creating and loading a config file as discussed in the next section.

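For instance, the deep reactive policy discussed later in this README can be substituted for the default straight-line plan directly from the API. The sketch below is illustrative only: the ``plan`` keyword argument and the constructor defaults are assumptions not documented above, so prefer the config-file route described in the next section for setting these options.

```python
import pyRDDLGym
from pyRDDLGym_jax.core.planner import (
    JaxBackpropPlanner, JaxDeepReactivePolicy, JaxOfflineController)

# set up the environment (vectorized must be True for the planner)
env = pyRDDLGym.make("domain", "instance", vectorized=True)

# illustrative sketch: swap the straight-line plan for a policy network with
# two hidden layers; the `plan` keyword name is an assumption here
policy = JaxDeepReactivePolicy(topology=[128, 64])
planner = JaxBackpropPlanner(rddl=env.model, plan=policy)
controller = JaxOfflineController(planner)

controller.evaluate(env, episodes=1, verbose=True, render=True)
env.close()
```
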
## Writing a Configuration File for a Custom Domain

The simplest way to interface with the planner for solving a custom problem is to write a configuration file with all the necessary hyper-parameters.
The basic structure of a configuration file is provided below for a straight-line planning/MPC style planner:

```ini
[Model]
logic='FuzzyLogic'
logic_kwargs={'weight': 20}
tnorm='ProductTNorm'
tnorm_kwargs={}

[Optimizer]
method='JaxStraightLinePlan'
method_kwargs={}
optimizer='rmsprop'
optimizer_kwargs={'learning_rate': 0.001}
batch_size_train=1
batch_size_test=1

[Training]
key=42
epochs=5000
train_seconds=30
```

The configuration file contains three sections:
- ``[Model]`` specifies the fuzzy logic operations used to relax discrete operations to differentiable approximations: the ``weight`` parameter, for example, dictates how closely the approximation fits the true operation (see the illustrative sketch after this list), and ``tnorm`` specifies the type of [fuzzy logic](https://en.wikipedia.org/wiki/T-norm_fuzzy_logics) used to replace logical operations in RDDL (e.g. ``ProductTNorm``, ``GodelTNorm``, ``LukasiewiczTNorm``)
- ``[Optimizer]`` specifies the optimizer and plan settings: ``method`` selects the plan/policy representation (e.g. ``JaxStraightLinePlan``, ``JaxDeepReactivePolicy``), and the remaining keys select the SGD optimizer from optax, the learning rate, the batch size, etc.
- ``[Training]`` specifies how long training should proceed: ``epochs`` limits the total number of iterations, while ``train_seconds`` limits the total training time

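To build intuition for the ``weight`` parameter, the following is a minimal, illustrative sketch of how a discrete comparison can be relaxed into a differentiable one; it shows the general idea only and is not the exact implementation used in ``pyRDDLGym_jax/core/logic.py``:

```python
import jax.numpy as jnp

def relaxed_greater(a, b, weight=20.0):
    # sigmoid(weight * (a - b)): approaches the exact 0/1 outcome of a > b as weight grows
    return 1.0 / (1.0 + jnp.exp(-weight * (a - b)))

def product_tnorm_and(x, y):
    # product t-norm: a fuzzy AND of two relaxed truth values
    return x * y

print(relaxed_greater(2.0, 1.0))    # close to 1: the comparison is clearly true
print(relaxed_greater(1.0, 1.01))   # soft value near 0.5: the operands are nearly equal
print(product_tnorm_and(relaxed_greater(2.0, 1.0), relaxed_greater(3.0, 1.0)))
```

A larger ``weight`` makes the relaxation sharper (closer to the true discrete operation) but typically makes its gradients less informative, which is the trade-off this hyper-parameter controls.
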
For a policy network approach, simply change the configuration file like so:

```ini
...
[Optimizer]
method='JaxDeepReactivePolicy'
method_kwargs={'topology': [128, 64]}
...
```

which would create a policy network with two hidden layers and ReLU activations.

The configuration file can then be passed to the planner during initialization. For example, the [previous script here](#running-from-the-python-api) can be modified to set parameters from a config file as follows:

```python
from pyRDDLGym_jax.core.planner import load_config

# load the config file with planner settings
planner_args, _, train_args = load_config("/path/to/config.cfg")

# create the planning algorithm
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
controller = JaxOfflineController(planner, **train_args)
...
```

## Changing the pyRDDLGym Simulation Backend to JAX

The JAX compiler can be used as a backend for simulating and evaluating RDDL environments, instead of the usual pyRDDLGym one:

```python
import pyRDDLGym
from pyRDDLGym.core.policy import RandomAgent
from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator

# create the environment
env = pyRDDLGym.make("domain", "instance", backend=JaxRDDLSimulator)

# evaluate the random policy
agent = RandomAgent(action_space=env.action_space,
                    num_actions=env.max_allowed_actions)
agent.evaluate(env, verbose=True, render=True)
```

For some domains, the JAX backend could perform better than the numpy-based one, due to various compiler optimizations.
In any event, the simulation results using the JAX backend should match exactly those of the numpy-based backend.

## Computing the Model Gradients Manually

For custom applications, it is often desirable to compute gradients of the model that can be optimized downstream. Fortunately, we provide a very convenient function for compiling the transition/step function ``P(s, a, s')`` of the environment into JAX.

```python
import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxRDDLCompilerWithGrad

# set up the environment
env = pyRDDLGym.make("domain", "instance", vectorized=True)

# create the step function
compiled = JaxRDDLCompilerWithGrad(rddl=env.model)
compiled.compile()
step_fn = compiled.compile_transition()
```

This will return a JAX compiled (pure) function that requires 4 arguments:
- ``key`` is the ``jax.random.PRNGKey`` key for reproducible randomness
- ``actions`` is the dictionary of action-fluent JAX tensors
- ``subs`` is the dictionary of state-fluent and non-fluent JAX tensors
- ``model_params`` are the parameters of the differentiable relaxations, such as ``weight``

The function returns a dictionary containing a variety of variables, such as the updated pvariables including next-state fluents (``pvar``), the reward obtained (``reward``), and error codes (``error``).
It is thus possible to apply any JAX transformation to the output of the function, such as computing gradients using ``jax.grad()`` or batched simulation using ``jax.vmap()``.

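As a minimal sketch of that last point (assuming ``step_fn`` from the snippet above, and that suitable ``actions``, ``subs``, and ``model_params`` dictionaries of JAX arrays are available), the reward returned by ``step_fn`` can be wrapped in a scalar function and then differentiated or batched:

```python
import jax

key = jax.random.PRNGKey(42)

def one_step_reward(actions, subs, model_params):
    # step_fn(key, actions, subs, model_params) returns a dictionary; take the reward entry
    return step_fn(key, actions, subs, model_params)['reward']

# gradient of the immediate reward with respect to the action-fluent tensors
reward_grad_fn = jax.grad(one_step_reward, argnums=0)

# batched evaluation over a leading batch axis added to each action tensor
batched_reward_fn = jax.vmap(one_step_reward, in_axes=(0, None, None))
```
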
Compilation of entire rollouts is possible by calling the ``compile_rollouts`` function and providing a policy implementation that maps states (JAX tensors) and tunable policy parameters to actions.
An [example is provided to illustrate how you can define your own policy class and compute the return gradient manually](https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/pyRDDLGym_jax/examples/run_gradient.py).

## Citing pyRDDLGym-jax

The main ideas of this approach are discussed in the following preprint:

```
@article{taitler2022pyrddlgym,
    title={pyRDDLGym: From RDDL to Gym Environments},
    author={Taitler, Ayal and Gimelfarb, Michael and Gopalakrishnan, Sriram and Mladenov, Martin and Liu, Xiaotian and Sanner, Scott},
    journal={arXiv preprint arXiv:2211.05939},
    year={2022}
}
```

Many of the implementation details discussed come from the following literature, which you may wish to cite in your research papers:
- [A Distributional Framework for Risk-Sensitive End-to-End Planning in Continuous MDP, AAAI 2022](https://ojs.aaai.org/index.php/AAAI/article/view/21226)
- [Deep reactive policies for planning in stochastic nonlinear domains, AAAI 2019](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
- [Scalable planning with tensorflow for hybrid nonlinear domains, NeurIPS 2017](https://proceedings.neurips.cc/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)