pyRDDLGym-jax 0.3.tar.gz → 0.4.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. pyrddlgym_jax-0.4/PKG-INFO +276 -0
  2. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/README.md +42 -44
  3. pyrddlgym_jax-0.4/pyRDDLGym_jax/__init__.py +1 -0
  4. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/core/compiler.py +90 -67
  5. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/core/logic.py +188 -46
  6. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/core/planner.py +59 -47
  7. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/core/simulator.py +2 -1
  8. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/core/tuning.py +7 -7
  9. pyrddlgym_jax-0.4/pyRDDLGym_jax.egg-info/PKG-INFO +276 -0
  10. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax.egg-info/requires.txt +0 -1
  11. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/setup.py +7 -3
  12. pyrddlgym_jax-0.3/PKG-INFO +0 -25
  13. pyrddlgym_jax-0.3/pyRDDLGym_jax/__init__.py +0 -1
  14. pyrddlgym_jax-0.3/pyRDDLGym_jax.egg-info/PKG-INFO +0 -25
  15. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/LICENSE +0 -0
  16. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/core/__init__.py +0 -0
  17. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/__init__.py +0 -0
  18. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +0 -0
  19. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +0 -0
  20. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +0 -0
  21. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +0 -0
  22. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +0 -0
  23. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -0
  24. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -0
  25. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +0 -0
  26. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +0 -0
  27. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -0
  28. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +0 -0
  29. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +0 -0
  30. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +0 -0
  31. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +0 -0
  32. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +0 -0
  33. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +0 -0
  34. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +0 -0
  35. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +0 -0
  36. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +0 -0
  37. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +0 -0
  38. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +0 -0
  39. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +0 -0
  40. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
  41. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/configs/default_drp.cfg +0 -0
  42. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/configs/default_replan.cfg +0 -0
  43. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/configs/default_slp.cfg +0 -0
  44. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/run_gradient.py +0 -0
  45. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/run_gym.py +0 -0
  46. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/run_plan.py +0 -0
  47. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/run_scipy.py +0 -0
  48. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax/examples/run_tune.py +0 -0
  49. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax.egg-info/SOURCES.txt +0 -0
  50. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax.egg-info/dependency_links.txt +0 -0
  51. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/pyRDDLGym_jax.egg-info/top_level.txt +0 -0
  52. {pyrddlgym_jax-0.3 → pyrddlgym_jax-0.4}/setup.cfg +0 -0
@@ -0,0 +1,276 @@
1
+ Metadata-Version: 2.1
2
+ Name: pyRDDLGym-jax
3
+ Version: 0.4
4
+ Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
5
+ Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
6
+ Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
7
+ Author-email: mike.gimelfarb@mail.utoronto.ca, ataitler@gmail.com, ssanner@mie.utoronto.ca
8
+ License: MIT License
9
+ Classifier: Development Status :: 3 - Alpha
10
+ Classifier: Intended Audience :: Science/Research
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Natural Language :: English
13
+ Classifier: Operating System :: OS Independent
14
+ Classifier: Programming Language :: Python :: 3
15
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
16
+ Requires-Python: >=3.8
17
+ Description-Content-Type: text/markdown
18
+ License-File: LICENSE
19
+ Requires-Dist: pyRDDLGym>=2.0
20
+ Requires-Dist: tqdm>=4.66
21
+ Requires-Dist: bayesian-optimization>=1.4.3
22
+ Requires-Dist: jax>=0.4.12
23
+ Requires-Dist: optax>=0.1.9
24
+ Requires-Dist: dm-haiku>=0.0.10
25
+ Requires-Dist: tensorflow-probability>=0.21.0
26
+
27
+ # pyRDDLGym-jax
28
+
29
+ Author: [Mike Gimelfarb](https://mike-gimelfarb.github.io)
30
+
31
+ This directory provides:
32
+ 1. automated translation and compilation of RDDL description files into [JAX](https://github.com/google/jax), converting any RDDL domain to a differentiable simulator!
33
+ 2. powerful, fast and scalable gradient-based planning algorithms, with extendible and flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and much more!
34
+
35
+ > [!NOTE]
36
+ > While the Jax planner can support some discrete state/action problems through model relaxations, it can perform poorly on some discrete problems (though there is an ongoing effort to remedy this!).
37
+ > If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
38
+
39
+ ## Contents
40
+
41
+ - [Installation](#installation)
42
+ - [Running from the Command Line](#running-from-the-command-line)
43
+ - [Running from within Python](#running-from-within-python)
44
+ - [Configuring the Planner](#configuring-the-planner)
45
+ - [Simulation](#simulation)
46
+ - [Manual Gradient Calculation](#manual-gradient-calculation)
47
+ - [Citing pyRDDLGym-jax](#citing-pyrddlgym-jax)
48
+
49
+ ## Installation
50
+
51
+ To use the compiler or planner without the automated hyper-parameter tuning, you will need the following packages installed:
52
+ - ``pyRDDLGym>=2.0``
53
+ - ``tqdm>=4.66``
54
+ - ``jax>=0.4.12``
55
+ - ``optax>=0.1.9``
56
+ - ``dm-haiku>=0.0.10``
57
+ - ``tensorflow-probability>=0.21.0``
58
+
59
+ Additionally, if you wish to run the examples, you need ``rddlrepository>=2``.
60
+ To run the automated tuning optimization, you will also need ``bayesian-optimization>=1.4.3``.
61
+
62
+ You can install this package, together with all of its requirements, via pip:
63
+
64
+ ```shell
65
+ pip install rddlrepository pyRDDLGym-jax
66
+ ```
67
+
68
+ ## Running from the Command Line
69
+
70
+ A basic run script is provided to run the Jax Planner on any domain in ``rddlrepository``, and can be launched in the command line from the install directory of pyRDDLGym-jax:
71
+
72
+ ```shell
73
+ python -m pyRDDLGym_jax.examples.run_plan <domain> <instance> <method> <episodes>
74
+ ```
75
+
76
+ where:
77
+ - ``domain`` is the domain identifier as specified in rddlrepository (i.e. Wildfire_MDP_ippc2014), or a path pointing to a valid ``domain.rddl`` file
78
+ - ``instance`` is the instance identifier (i.e. 1, 2, ... 10), or a path pointing to a valid ``instance.rddl`` file
79
+ - ``method`` is the planning method to use (i.e. drp, slp, replan)
80
+ - ``episodes`` is the (optional) number of episodes to evaluate the learned policy.
81
+
82
+ The ``method`` parameter supports three possible modes:
83
+ - ``slp`` is the basic straight line planner described [in this paper](https://proceedings.neurips.cc/paper_files/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)
84
+ - ``drp`` is the deep reactive policy network described [in this paper](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
85
+ - ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step.
86
+
87
+ A basic run script is also provided to run the automatic hyper-parameter tuning:
88
+
89
+ ```shell
90
+ python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials> <iters> <workers>
91
+ ```
92
+
93
+ where:
94
+ - ``domain`` is the domain identifier as specified in rddlrepository (i.e. Wildfire_MDP_ippc2014)
95
+ - ``instance`` is the instance identifier (i.e. 1, 2, ... 10)
96
+ - ``method`` is the planning method to use (i.e. drp, slp, replan)
97
+ - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
98
+ - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
99
+ - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, i.e. the total evaluations = ``iters * workers``.
100
+
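For instance, a complete tuning invocation might look like the following (the numeric arguments are illustrative placeholders rather than recommended settings):

```shell
python -m pyRDDLGym_jax.examples.run_tune Wildfire_MDP_ippc2014 1 slp 5 20 4
```

This would average 5 episodes per candidate hyper-parameter setting, run 20 rounds of Bayesian optimization, and evaluate 4 candidates in parallel at each round, i.e. 80 evaluations in total.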
101
+ For example, the following will train the Jax Planner on the Quadcopter domain with 4 drones:
102
+
103
+ ```shell
104
+ python -m pyRDDLGym_jax.examples.run_plan Quadcopter 1 slp
105
+ ```
106
+
107
+ After several minutes of optimization, you should get a visualization as follows:
108
+
109
+ <p align="center">
110
+ <img src="Images/quadcopter.gif" width="400" height="400" margin=1/>
111
+ </p>
112
+
113
+ ## Running from within Python
114
+
115
+ To run the Jax planner from within a Python application, refer to the following example:
116
+
117
+ ```python
118
+ import pyRDDLGym
119
+ from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOfflineController
120
+
121
+ # set up the environment (note the vectorized option must be True)
122
+ env = pyRDDLGym.make("domain", "instance", vectorized=True)
123
+
124
+ # create the planning algorithm
125
+ planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
126
+ controller = JaxOfflineController(planner, **train_args)
127
+
128
+ # evaluate the planner
129
+ controller.evaluate(env, episodes=1, verbose=True, render=True)
130
+ env.close()
131
+ ```
132
+
133
+ Here, we have used the straight-line controller, although you can configure the combination of planner and policy representation if you wish.
134
+ All controllers are instances of pyRDDLGym's ``BaseAgent`` class, so they provide the ``evaluate()`` function to streamline interaction with the environment.
135
+ The ``**planner_args`` and ``**train_args`` are keyword arguments to pass during initialization, but we strongly recommend creating and loading a config file as discussed in the next section.
136
+
137
+ ## Configuring the Planner
138
+
139
+ The simplest way to configure the planner is to write and pass a configuration file with the necessary [hyper-parameters](https://pyrddlgym.readthedocs.io/en/latest/jax.html#configuring-pyrddlgym-jax).
140
+ The basic structure of a configuration file is provided below for a straight-line planner:
141
+
142
+ ```ini
143
+ [Model]
144
+ logic='FuzzyLogic'
145
+ logic_kwargs={'weight': 20}
146
+ tnorm='ProductTNorm'
147
+ tnorm_kwargs={}
148
+
149
+ [Optimizer]
150
+ method='JaxStraightLinePlan'
151
+ method_kwargs={}
152
+ optimizer='rmsprop'
153
+ optimizer_kwargs={'learning_rate': 0.001}
154
+ batch_size_train=1
155
+ batch_size_test=1
156
+
157
+ [Training]
158
+ key=42
159
+ epochs=5000
160
+ train_seconds=30
161
+ ```
162
+
163
+ The configuration file contains three sections:
164
+ - ``[Model]`` specifies the fuzzy logic operations used to relax discrete operations to differentiable approximations; the ``weight`` dictates the quality of the approximation,
165
+ and ``tnorm`` specifies the type of [fuzzy logic](https://en.wikipedia.org/wiki/T-norm_fuzzy_logics) for replacing logical operations in RDDL (e.g. ``ProductTNorm``, ``GodelTNorm``, ``LukasiewiczTNorm``)
166
+ - ``[Optimizer]`` generally specifies the optimizer and plan settings; the ``method`` specifies the plan/policy representation (e.g. ``JaxStraightLinePlan``, ``JaxDeepReactivePolicy``), the gradient descent settings, learning rate, batch size, etc.
167
+ - ``[Training]`` specifies computation limits, such as total training time and number of iterations, and options for printing or visualizing information from the planner.
168
+
169
+ For a policy network approach, simply change the ``[Optimizer]`` settings like so:
170
+
171
+ ```ini
172
+ ...
173
+ [Optimizer]
174
+ method='JaxDeepReactivePolicy'
175
+ method_kwargs={'topology': [128, 64], 'activation': 'tanh'}
176
+ ...
177
+ ```
178
+
179
+ The configuration file must then be passed to the planner during initialization.
180
+ For example, the [previous script here](#running-from-within-python) can be modified to set parameters from a config file:
181
+
182
+ ```python
183
+ from pyRDDLGym_jax.core.planner import load_config
184
+
185
+ # load the config file with planner settings
186
+ planner_args, _, train_args = load_config("/path/to/config.cfg")
187
+
188
+ # create the planning algorithm
189
+ planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
190
+ controller = JaxOfflineController(planner, **train_args)
191
+ ...
192
+ ```
193
+
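If you prefer to configure everything in code rather than through a config file, the same settings can also be assembled programmatically. The sketch below mirrors the straight-line config above; the import locations and constructor keywords (``plan``, ``logic``, ``optimizer``, ``optimizer_kwargs``, ``batch_size_train``) are inferred from the config field names and should be checked against the actual class signatures before use:

```python
import optax
import pyRDDLGym
from pyRDDLGym_jax.core.logic import FuzzyLogic
from pyRDDLGym_jax.core.planner import (
    JaxBackpropPlanner, JaxOfflineController, JaxStraightLinePlan)

env = pyRDDLGym.make("domain", "instance", vectorized=True)

# keyword names below are assumptions mirroring the [Model] and [Optimizer] sections
planner = JaxBackpropPlanner(
    rddl=env.model,
    plan=JaxStraightLinePlan(),                  # method='JaxStraightLinePlan'
    logic=FuzzyLogic(weight=20),                 # logic='FuzzyLogic' with weight 20
    optimizer=optax.rmsprop,                     # optimizer='rmsprop'
    optimizer_kwargs={'learning_rate': 0.001},
    batch_size_train=1,
    batch_size_test=1)

# training-time options such as epochs and train_seconds are passed to the
# controller, in the same way **train_args was used in the earlier example
controller = JaxOfflineController(planner, epochs=5000, train_seconds=30)
```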
194
+ ## Simulation
195
+
196
+ The JAX compiler can be used as a backend for simulating and evaluating RDDL environments:
197
+
198
+ ```python
199
+ import pyRDDLGym
200
+ from pyRDDLGym.core.policy import RandomAgent
201
+ from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator
202
+
203
+ # create the environment
204
+ env = pyRDDLGym.make("domain", "instance", backend=JaxRDDLSimulator)
205
+
206
+ # evaluate the random policy
207
+ agent = RandomAgent(action_space=env.action_space,
208
+ num_actions=env.max_allowed_actions)
209
+ agent.evaluate(env, verbose=True, render=True)
210
+ ```
211
+
212
+ For some domains, the JAX backend could perform better than the numpy-based one, due to various compiler optimizations.
213
+ In any event, the simulation results using the JAX backend should (almost) always match the numpy backend.
214
+
215
+ ## Manual Gradient Calculation
216
+
217
+ For custom applications, it is desirable to compute gradients of the model that can be optimized downstream.
218
+ Fortunately, we provide a very convenient function for compiling the transition/step function ``P(s, a, s')`` of the environment into JAX.
219
+
220
+ ```python
221
+ import pyRDDLGym
222
+ from pyRDDLGym_jax.core.planner import JaxRDDLCompilerWithGrad
223
+
224
+ # set up the environment
225
+ env = pyRDDLGym.make("domain", "instance", vectorized=True)
226
+
227
+ # create the step function
228
+ compiled = JaxRDDLCompilerWithGrad(rddl=env.model)
229
+ compiled.compile()
230
+ step_fn = compiled.compile_transition()
231
+ ```
232
+
233
+ This will return a JAX compiled (pure) function requiring the following inputs:
234
+ - ``key`` is the ``jax.random.PRNGKey`` key for reproducible randomness
235
+ - ``actions`` is the dictionary of action fluent tensors
236
+ - ``subs`` is the dictionary of state-fluent and non-fluent tensors
237
+ - ``model_params`` are the parameters of the differentiable relaxations, such as ``weight``
238
+
239
+ The function returns a dictionary containing a variety of variables, such as the updated pvariables including next-state fluents (``pvar``), the reward obtained (``reward``), and error codes (``error``).
240
+ It is thus possible to apply any JAX transformation to the output of the function, such as computing gradient using ``jax.grad()`` or batched simulation using ``jax.vmap()``.
241
+
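As a rough sketch of how these transformations compose with the compiled step function (assuming the positional argument order matches the list of inputs above and that ``reward`` is a scalar for a single transition), one could write:

```python
import jax

# wrap the compiled step function so that it returns only the scalar reward;
# the argument order and the 'reward' key follow the description above and
# should be verified against compile_transition's actual signature
def transition_reward(actions, key, subs, model_params):
    return step_fn(key, actions, subs, model_params)['reward']

# gradient of the one-step reward with respect to the action-fluent tensors
reward_grad_fn = jax.grad(transition_reward)

# batched simulation: vectorize over a leading batch axis of the actions only
batched_step_fn = jax.vmap(step_fn, in_axes=(None, 0, None, None))
```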
242
+ Compilation of entire rollouts is also possible by calling the ``compile_rollouts`` function.
243
+ An [example is provided to illustrate how you can define your own policy class and compute the return gradient manually](https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/pyRDDLGym_jax/examples/run_gradient.py).
244
+
245
+ ## Citing pyRDDLGym-jax
246
+
247
+ The [following citation](https://ojs.aaai.org/index.php/ICAPS/article/view/31480) describes the main ideas of the framework. Please cite it if you found it useful:
248
+
249
+ ```
250
+ @inproceedings{gimelfarb2024jaxplan,
251
+ title={JaxPlan and GurobiPlan: Optimization Baselines for Replanning in Discrete and Mixed Discrete and Continuous Probabilistic Domains},
252
+ author={Michael Gimelfarb and Ayal Taitler and Scott Sanner},
253
+ booktitle={34th International Conference on Automated Planning and Scheduling},
254
+ year={2024},
255
+ url={https://openreview.net/forum?id=7IKtmUpLEH}
256
+ }
257
+ ```
258
+
259
+ The utility optimization is discussed in [this paper](https://ojs.aaai.org/index.php/AAAI/article/view/21226):
260
+
261
+ ```
262
+ @inproceedings{patton2022distributional,
263
+ title={A distributional framework for risk-sensitive end-to-end planning in continuous mdps},
264
+ author={Patton, Noah and Jeong, Jihwan and Gimelfarb, Mike and Sanner, Scott},
265
+ booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
266
+ volume={36},
267
+ number={9},
268
+ pages={9894--9901},
269
+ year={2022}
270
+ }
271
+ ```
272
+
273
+ Some of the implementation details derive from the following literature, which you may wish to also cite in your research papers:
274
+ - [Deep reactive policies for planning in stochastic nonlinear domains, AAAI 2019](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
275
+ - [Scalable planning with tensorflow for hybrid nonlinear domains, NeurIPS 2017](https://proceedings.neurips.cc/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)
276
+
@@ -3,8 +3,8 @@
3
3
  Author: [Mike Gimelfarb](https://mike-gimelfarb.github.io)
4
4
 
5
5
  This directory provides:
6
- 1. automated translation and compilation of RDDL description files into the [JAX](https://github.com/google/jax) auto-diff library, which allows any RDDL domain to be converted to a differentiable simulator!
7
- 2. powerful, fast, and very scalable gradient-based planning algorithms, with extendible and flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and much more!
6
+ 1. automated translation and compilation of RDDL description files into [JAX](https://github.com/google/jax), converting any RDDL domain to a differentiable simulator!
7
+ 2. powerful, fast and scalable gradient-based planning algorithms, with extendible and flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and much more!
8
8
 
9
9
  > [!NOTE]
10
10
  > While the Jax planner can support some discrete state/action problems through model relaxations, it can perform poorly on some discrete problems (though there is an ongoing effort to remedy this!).
@@ -13,11 +13,11 @@ This directory provides:
13
13
  ## Contents
14
14
 
15
15
  - [Installation](#installation)
16
- - [Running the Basic Examples](#running-the-basic-examples)
17
- - [Running from the Python API](#running-from-the-python-api)
18
- - [Writing a Configuration File for a Custom Domain](#writing-a-configuration-file-for-a-custom-domain)
19
- - [Changing the pyRDDLGym Simulation Backend to JAX](#changing-the-pyrddlgym-simulation-backend-to-jax)
20
- - [Computing the Model Gradients Manually](#computing-the-model-gradients-manually)
16
+ - [Running from the Command Line](#running-from-the-command-line)
17
+ - [Running from within Python](#running-from-within-python)
18
+ - [Configuring the Planner](#configuring-the-planner)
19
+ - [Simulation](#simulation)
20
+ - [Manual Gradient Calculation](#manual-gradient-calculation)
21
21
  - [Citing pyRDDLGym-jax](#citing-pyrddlgym-jax)
22
22
 
23
23
  ## Installation
@@ -28,21 +28,20 @@ To use the compiler or planner without the automated hyper-parameter tuning, you
28
28
  - ``jax>=0.4.12``
29
29
  - ``optax>=0.1.9``
30
30
  - ``dm-haiku>=0.0.10``
31
- - ``tensorflow>=2.13.0``
32
31
  - ``tensorflow-probability>=0.21.0``
33
32
 
34
- Additionally, if you wish to run the examples, you need ``rddlrepository>=2``, and run the automated tuning optimization, you will also need ``bayesian-optimization>=1.4.3``.
33
+ Additionally, if you wish to run the examples, you need ``rddlrepository>=2``.
34
+ To run the automated tuning optimization, you will also need ``bayesian-optimization>=1.4.3``.
35
35
 
36
- You can install this package, together with all of its requirements via pip:
36
+ You can install this package, together with all of its requirements, via pip:
37
37
 
38
38
  ```shell
39
39
  pip install rddlrepository pyRDDLGym-jax
40
40
  ```
41
41
 
42
- ## Running the Basic Examples
42
+ ## Running from the Command Line
43
43
 
44
- A basic run script is provided to run the Jax Planner on any domain in ``rddlrepository``, provided a config file is available (currently, only a limited subset of configs are provided as examples).
45
- The example can be run as follows in a standard shell, from the install directory of pyRDDLGym-jax:
44
+ A basic run script is provided to run the Jax Planner on any domain in ``rddlrepository``, and can be launched in the command line from the install directory of pyRDDLGym-jax:
46
45
 
47
46
  ```shell
48
47
  python -m pyRDDLGym_jax.examples.run_plan <domain> <instance> <method> <episodes>
@@ -52,14 +51,14 @@ where:
52
51
  - ``domain`` is the domain identifier as specified in rddlrepository (i.e. Wildfire_MDP_ippc2014), or a path pointing to a valid ``domain.rddl`` file
53
52
  - ``instance`` is the instance identifier (i.e. 1, 2, ... 10), or a path pointing to a valid ``instance.rddl`` file
54
53
  - ``method`` is the planning method to use (i.e. drp, slp, replan)
55
- - ``episodes`` is the (optional) number of episodes to evaluate the learned policy
54
+ - ``episodes`` is the (optional) number of episodes to evaluate the learned policy.
56
55
 
57
- The ``method`` parameter warrants further explanation. Currently we support three possible modes:
56
+ The ``method`` parameter supports three possible modes:
58
57
  - ``slp`` is the basic straight line planner described [in this paper](https://proceedings.neurips.cc/paper_files/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)
59
58
  - ``drp`` is the deep reactive policy network described [in this paper](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
60
59
  - ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step.
61
60
 
62
- A basic run script is also provided to run the automatic hyper-parameter tuning. The structure of this stript is similar to the one above
61
+ A basic run script is also provided to run the automatic hyper-parameter tuning:
63
62
 
64
63
  ```shell
65
64
  python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials> <iters> <workers>
@@ -71,9 +70,9 @@ where:
71
70
  - ``method`` is the planning method to use (i.e. drp, slp, replan)
72
71
  - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
73
72
  - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
74
- - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``
73
+ - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, i.e. the total evaluations = ``iters * workers``.
75
74
 
76
- For example, copy and pasting the following will train the Jax Planner on the Quadcopter domain with 4 drones:
75
+ For example, the following will train the Jax Planner on the Quadcopter domain with 4 drones:
77
76
 
78
77
  ```shell
79
78
  python -m pyRDDLGym_jax.examples.run_plan Quadcopter 1 slp
@@ -85,9 +84,9 @@ After several minutes of optimization, you should get a visualization as follows
85
84
  <img src="Images/quadcopter.gif" width="400" height="400" margin=1/>
86
85
  </p>
87
86
 
88
- ## Running from the Python API
87
+ ## Running from within Python
89
88
 
90
- You can write a simple script to instantiate and run the planner yourself, if you wish:
89
+ To run the Jax planner from within a Python application, refer to the following example:
91
90
 
92
91
  ```python
93
92
  import pyRDDLGym
@@ -102,18 +101,17 @@ controller = JaxOfflineController(planner, **train_args)
102
101
 
103
102
  # evaluate the planner
104
103
  controller.evaluate(env, episodes=1, verbose=True, render=True)
105
-
106
104
  env.close()
107
105
  ```
108
106
 
109
- Here, we have used the straight-line (offline) controller, although you can configure the combination of planner and policy representation.
110
- All controllers are instances of pyRDDLGym's ``BaseAgent`` class, so they support the ``evaluate()`` function to streamline interaction with the environment.
107
+ Here, we have used the straight-line controller, although you can configure the combination of planner and policy representation if you wish.
108
+ All controllers are instances of pyRDDLGym's ``BaseAgent`` class, so they provide the ``evaluate()`` function to streamline interaction with the environment.
111
109
  The ``**planner_args`` and ``**train_args`` are keyword arguments to pass during initialization, but we strongly recommend creating and loading a config file as discussed in the next section.
112
110
 
113
- ## Writing a Configuration File for a Custom Domain
111
+ ## Configuring the Planner
114
112
 
115
- The simplest way to interface with the Planner for solving a custom problem is to write a configuration file with all the necessary hyper-parameters.
116
- The basic structure of a configuration file is provided below for a straight line planning/MPC style planner:
113
+ The simplest way to configure the planner is to write and pass a configuration file with the necessary [hyper-parameters](https://pyrddlgym.readthedocs.io/en/latest/jax.html#configuring-pyrddlgym-jax).
114
+ The basic structure of a configuration file is provided below for a straight-line planner:
117
115
 
118
116
  ```ini
119
117
  [Model]
@@ -137,24 +135,23 @@ train_seconds=30
137
135
  ```
138
136
 
139
137
  The configuration file contains three sections:
140
- - ``[Model]`` specifies the fuzzy logic operations used to relax discrete operations to differentiable approximations, the ``weight`` parameter for example dictates how well the approximation will fit to the true operation,
138
+ - ``[Model]`` specifies the fuzzy logic operations used to relax discrete operations to differentiable approximations; the ``weight`` dictates the quality of the approximation,
141
139
  and ``tnorm`` specifies the type of [fuzzy logic](https://en.wikipedia.org/wiki/T-norm_fuzzy_logics) for replacing logical operations in RDDL (e.g. ``ProductTNorm``, ``GodelTNorm``, ``LukasiewiczTNorm``)
142
- - ``[Optimizer]`` generally specify the optimizer and plan settings, the ``method`` specifies the plan/policy representation (e.g. ``JaxStraightLinePlan``, ``JaxDeepReactivePolicy``), the SGD optimizer to use from optax, learning rate, batch size, etc.
143
- - ``[Training]`` specifies how long training should proceed, the ``epochs`` limits the total number of iterations, while ``train_seconds`` limits total training time
140
+ - ``[Optimizer]`` generally specifies the optimizer and plan settings; the ``method`` specifies the plan/policy representation (e.g. ``JaxStraightLinePlan``, ``JaxDeepReactivePolicy``), the gradient descent settings, learning rate, batch size, etc.
141
+ - ``[Training]`` specifies computation limits, such as total training time and number of iterations, and options for printing or visualizing information from the planner.
144
142
 
145
- For a policy network approach, simply change the configuration file like so:
143
+ For a policy network approach, simply change the ``[Optimizer]`` settings like so:
146
144
 
147
145
  ```ini
148
146
  ...
149
147
  [Optimizer]
150
148
  method='JaxDeepReactivePolicy'
151
- method_kwargs={'topology': [128, 64]}
149
+ method_kwargs={'topology': [128, 64], 'activation': 'tanh'}
152
150
  ...
153
151
  ```
154
152
 
155
- which would create a policy network with two hidden layers and ReLU activations.
156
-
157
- The configuration file can then be passed to the planner during initialization. For example, the [previous script here](#running-from-the-python-api) can be modified to set parameters from a config file as follows:
153
+ The configuration file must then be passed to the planner during initialization.
154
+ For example, the [previous script here](#running-from-within-python) can be modified to set parameters from a config file:
158
155
 
159
156
  ```python
160
157
  from pyRDDLGym_jax.core.planner import load_config
@@ -168,9 +165,9 @@ controller = JaxOfflineController(planner, **train_args)
168
165
  ...
169
166
  ```
170
167
 
171
- ## Changing the pyRDDLGym Simulation Backend to JAX
168
+ ## Simulation
172
169
 
173
- The JAX compiler can be used as a backend for simulating and evaluating RDDL environments, instead of the usual pyRDDLGym one:
170
+ The JAX compiler can be used as a backend for simulating and evaluating RDDL environments:
174
171
 
175
172
  ```python
176
173
  import pyRDDLGym
@@ -187,11 +184,12 @@ agent.evaluate(env, verbose=True, render=True)
187
184
  ```
188
185
 
189
186
  For some domains, the JAX backend could perform better than the numpy-based one, due to various compiler optimizations.
190
- In any event, the simulation results using the JAX backend should match exactly those of the numpy-based backend.
187
+ In any event, the simulation results using the JAX backend should (almost) always match the numpy backend.
191
188
 
192
- ## Computing the Model Gradients Manually
189
+ ## Manual Gradient Calculation
193
190
 
194
- For custom applications, it is desirable to compute gradients of the model that can be optimized downstream. Fortunately, we provide a very convenient function for compiling the transition/step function ``P(s, a, s')`` of the environment into JAX.
191
+ For custom applications, it is desirable to compute gradients of the model that can be optimized downstream.
192
+ Fortunately, we provide a very convenient function for compiling the transition/step function ``P(s, a, s')`` of the environment into JAX.
195
193
 
196
194
  ```python
197
195
  import pyRDDLGym
@@ -206,16 +204,16 @@ compiled.compile()
206
204
  step_fn = compiled.compile_transition()
207
205
  ```
208
206
 
209
- This will return a JAX compiled (pure) function that requires 4 arguments:
207
+ This will return a JAX compiled (pure) function requiring the following inputs:
210
208
  - ``key`` is the ``jax.random.PRNGKey`` key for reproducible randomness
211
- - ``actions`` is the dictionary of action fluent JAX tensors
212
- - ``subs`` is the dictionary of state-fluent and non-fluent JAX tensors
209
+ - ``actions`` is the dictionary of action fluent tensors
210
+ - ``subs`` is the dictionary of state-fluent and non-fluent tensors
213
211
  - ``model_params`` are the parameters of the differentiable relaxations, such as ``weight``
214
- -
212
+
215
213
  The function returns a dictionary containing a variety of variables, such as the updated pvariables including next-state fluents (``pvar``), the reward obtained (``reward``), and error codes (``error``).
216
214
  It is thus possible to apply any JAX transformation to the output of the function, such as computing gradient using ``jax.grad()`` or batched simulation using ``jax.vmap()``.
217
215
 
218
- Compilation of entire rollouts is possible by calling the ``compile_rollouts`` function, and providing a policy implementation that maps states (jax tensors) and tunable policy parameters to actions.
216
+ Compilation of entire rollouts is also possible by calling the ``compile_rollouts`` function.
219
217
  An [example is provided to illustrate how you can define your own policy class and compute the return gradient manually](https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/pyRDDLGym_jax/examples/run_gradient.py).
220
218
 
221
219
  ## Citing pyRDDLGym-jax
@@ -0,0 +1 @@
1
+ __version__ = '0.4'