pyRDDLGym-jax-0.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. pyRDDLGym-jax-0.1/LICENSE +21 -0
  2. pyRDDLGym-jax-0.1/PKG-INFO +25 -0
  3. pyRDDLGym-jax-0.1/README.md +250 -0
  4. pyRDDLGym-jax-0.1/pyRDDLGym_jax/__init__.py +0 -0
  5. pyRDDLGym-jax-0.1/pyRDDLGym_jax/core/__init__.py +0 -0
  6. pyRDDLGym-jax-0.1/pyRDDLGym_jax/core/compiler.py +1781 -0
  7. pyRDDLGym-jax-0.1/pyRDDLGym_jax/core/logic.py +572 -0
  8. pyRDDLGym-jax-0.1/pyRDDLGym_jax/core/planner.py +1675 -0
  9. pyRDDLGym-jax-0.1/pyRDDLGym_jax/core/simulator.py +197 -0
  10. pyRDDLGym-jax-0.1/pyRDDLGym_jax/core/tuning.py +691 -0
  11. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/__init__.py +0 -0
  12. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_drp.cfg +19 -0
  13. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_replan.cfg +19 -0
  14. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_slp.cfg +19 -0
  15. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/HVAC_drp.cfg +18 -0
  16. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/HVAC_slp.cfg +18 -0
  17. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/MarsRover_drp.cfg +18 -0
  18. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/MarsRover_slp.cfg +20 -0
  19. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/MountainCar_slp.cfg +19 -0
  20. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/Pendulum_slp.cfg +18 -0
  21. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/Pong_slp.cfg +18 -0
  22. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/PowerGen_drp.cfg +18 -0
  23. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/PowerGen_replan.cfg +19 -0
  24. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/PowerGen_slp.cfg +18 -0
  25. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +18 -0
  26. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +19 -0
  27. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +18 -0
  28. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/SupplyChain_slp.cfg +18 -0
  29. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/Traffic_slp.cfg +20 -0
  30. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +18 -0
  31. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +18 -0
  32. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +20 -0
  33. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +19 -0
  34. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
  35. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/run_gradient.py +102 -0
  36. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/run_gym.py +48 -0
  37. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/run_plan.py +61 -0
  38. pyRDDLGym-jax-0.1/pyRDDLGym_jax/examples/run_tune.py +75 -0
  39. pyRDDLGym-jax-0.1/pyRDDLGym_jax.egg-info/PKG-INFO +25 -0
  40. pyRDDLGym-jax-0.1/pyRDDLGym_jax.egg-info/SOURCES.txt +43 -0
  41. pyRDDLGym-jax-0.1/pyRDDLGym_jax.egg-info/dependency_links.txt +1 -0
  42. pyRDDLGym-jax-0.1/pyRDDLGym_jax.egg-info/requires.txt +8 -0
  43. pyRDDLGym-jax-0.1/pyRDDLGym_jax.egg-info/top_level.txt +1 -0
  44. pyRDDLGym-jax-0.1/setup.cfg +4 -0
  45. pyRDDLGym-jax-0.1/setup.py +49 -0
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 pyrddlgym-project
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
@@ -0,0 +1,25 @@
+ Metadata-Version: 2.1
+ Name: pyRDDLGym-jax
+ Version: 0.1
+ Summary: pyRDDLGym-jax: JAX compilation of RDDL description files, and a differentiable planner in JAX.
+ Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
+ Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
+ Author-email: mike.gimelfarb@mail.utoronto.ca, ataitler@gmail.com, ssanner@mie.utoronto.ca
+ License: MIT License
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Natural Language :: English
+ Classifier: Operating System :: OS Independent
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.8
+ License-File: LICENSE
+ Requires-Dist: pyRDDLGym>=2.0
+ Requires-Dist: tqdm>=4.66
+ Requires-Dist: bayesian-optimization>=1.4.3
+ Requires-Dist: jax>=0.4.12
+ Requires-Dist: optax>=0.1.9
+ Requires-Dist: dm-haiku>=0.0.10
+ Requires-Dist: tensorflow>=2.13.0
+ Requires-Dist: tensorflow-probability>=0.21.0
@@ -0,0 +1,250 @@
+ # pyRDDLGym-jax
+
+ Author: [Mike Gimelfarb](https://mike-gimelfarb.github.io)
+
+ This directory provides:
+ 1. automated translation and compilation of RDDL description files into the [JAX](https://github.com/google/jax) auto-diff library, which allows any RDDL domain to be converted to a differentiable simulator!
+ 2. powerful, fast, and scalable gradient-based planning algorithms, with extensible and flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and much more!
+
+ > [!NOTE]
+ > While the JAX planner supports some discrete state/action problems through model relaxations, it can perform poorly on some discrete problems (though there is an ongoing effort to remedy this!).
+ > If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
+
+ ## Contents
+
+ - [Installation](#installation)
+ - [Running the Basic Examples](#running-the-basic-examples)
+ - [Running from the Python API](#running-from-the-python-api)
+ - [Writing a Configuration File for a Custom Domain](#writing-a-configuration-file-for-a-custom-domain)
+ - [Changing the pyRDDLGym Simulation Backend to JAX](#changing-the-pyrddlgym-simulation-backend-to-jax)
+ - [Computing the Model Gradients Manually](#computing-the-model-gradients-manually)
+ - [Citing pyRDDLGym-jax](#citing-pyrddlgym-jax)
+
+ ## Installation
+
+ To use the compiler or planner without the automated hyper-parameter tuning, you will need the following packages installed:
+ - ``pyRDDLGym>=2.0``
+ - ``tqdm>=4.66``
+ - ``jax>=0.4.12``
+ - ``optax>=0.1.9``
+ - ``dm-haiku>=0.0.10``
+ - ``tensorflow>=2.13.0``
+ - ``tensorflow-probability>=0.21.0``
+
+ Additionally, if you wish to run the examples you will need ``rddlrepository>=2``, and to run the automated tuning optimization you will also need ``bayesian-optimization>=1.4.3``.
+
+ You can install this package, together with all of its requirements, as follows (assuming Anaconda):
+
+ ```shell
+ # Create a new conda environment
+ conda create -n jaxplan python=3.11
+ conda activate jaxplan
+ conda install pip git
+
+ # Manually install pyRDDLGym and rddlrepository
+ pip install git+https://github.com/pyrddlgym-project/pyRDDLGym
+ pip install git+https://github.com/pyrddlgym-project/rddlrepository
+
+ # Install pyRDDLGym-jax
+ pip install git+https://github.com/pyrddlgym-project/pyRDDLGym-jax
+ ```
+
+ A pip installer will be coming soon.
+
+ ## Running the Basic Examples
+
+ A basic run script is provided to run the Jax Planner on any domain in ``rddlrepository``, provided a config file is available (currently, example config files are provided for only a limited subset of domains).
+ The example can be run as follows in a standard shell, from the install directory of pyRDDLGym-jax:
+
+ ```shell
+ python -m pyRDDLGym_jax.examples.run_plan <domain> <instance> <method> <episodes>
+ ```
+
+ where:
+ - ``domain`` is the domain identifier as specified in rddlrepository (e.g. Wildfire_MDP_ippc2014), or a path pointing to a valid ``domain.rddl`` file
+ - ``instance`` is the instance identifier (e.g. 1, 2, ... 10), or a path pointing to a valid ``instance.rddl`` file
+ - ``method`` is the planning method to use (i.e. drp, slp, replan)
+ - ``episodes`` is the (optional) number of episodes over which to evaluate the learned policy
+
+ The ``method`` parameter warrants further explanation. Currently, we support three possible modes:
+ - ``slp`` is the basic straight-line planner described [in this paper](https://proceedings.neurips.cc/paper_files/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)
+ - ``drp`` is the deep reactive policy network described [in this paper](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
+ - ``replan`` is the same as ``slp``, except the plan is recalculated at every decision time step.
+
+ A basic run script is also provided to run the automatic hyper-parameter tuning. The structure of this script is similar to the one above:
+
+ ```shell
+ python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials> <iters> <workers>
+ ```
+
+ where:
+ - ``domain`` is the domain identifier as specified in rddlrepository (e.g. Wildfire_MDP_ippc2014)
+ - ``instance`` is the instance identifier (e.g. 1, 2, ... 10)
+ - ``method`` is the planning method to use (i.e. drp, slp, replan)
+ - ``trials`` is the (optional) number of trials/episodes to average when evaluating each hyper-parameter setting
+ - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
+ - ``workers`` is the (optional) number of parallel evaluations to perform at each iteration, so the total number of evaluations is ``iters * workers``
+
+ For example, copying and pasting the following will train the Jax Planner on the Quadcopter domain with 4 drones:
+
+ ```shell
+ python -m pyRDDLGym_jax.examples.run_plan Quadcopter 1 slp
+ ```
+
+ After several minutes of optimization, you should get a visualization as follows:
+
+ <p align="center">
+ <img src="Images/quadcopter.gif" width="400" height="400" margin=1/>
+ </p>
+
+ ## Running from the Python API
+
+ You can write a simple script to instantiate and run the planner yourself, if you wish:
+
+ ```python
+ import pyRDDLGym
+ from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOfflineController
+
+ # set up the environment (note the vectorized option must be True)
+ env = pyRDDLGym.make("domain", "instance", vectorized=True)
+
+ # create the planning algorithm
+ planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
+ controller = JaxOfflineController(planner, **train_args)
+
+ # evaluate the planner
+ controller.evaluate(env, episodes=1, verbose=True, render=True)
+
+ env.close()
+ ```
+
+ Here, we have used the straight-line (offline) controller, although you can configure the combination of planner and policy representation.
+ All controllers are instances of pyRDDLGym's ``BaseAgent`` class, so they support the ``evaluate()`` function to streamline interaction with the environment.
+ The ``**planner_args`` and ``**train_args`` are keyword arguments passed during initialization, but we strongly recommend creating and loading a config file as discussed in the next section.
+
+ ## Writing a Configuration File for a Custom Domain
+
+ The simplest way to interface with the planner for solving a custom problem is to write a configuration file with all the necessary hyper-parameters.
+ The basic structure of a configuration file is provided below for a straight-line planning/MPC-style planner:
+
+ ```ini
+ [Model]
+ logic='FuzzyLogic'
+ logic_kwargs={'weight': 20}
+ tnorm='ProductTNorm'
+ tnorm_kwargs={}
+
+ [Optimizer]
+ method='JaxStraightLinePlan'
+ method_kwargs={}
+ optimizer='rmsprop'
+ optimizer_kwargs={'learning_rate': 0.001}
+ batch_size_train=1
+ batch_size_test=1
+
+ [Training]
+ key=42
+ epochs=5000
+ train_seconds=30
+ ```
+
+ The configuration file contains three sections:
+ - ``[Model]`` specifies the fuzzy logic operations used to relax discrete operations to differentiable approximations; the ``weight`` parameter, for example, dictates how closely the approximation fits the true operation,
+ and ``tnorm`` specifies the type of [fuzzy logic](https://en.wikipedia.org/wiki/T-norm_fuzzy_logics) used to replace logical operations in RDDL (e.g. ``ProductTNorm``, ``GodelTNorm``, ``LukasiewiczTNorm``); a conceptual sketch of such a relaxation is given after this list
+ - ``[Optimizer]`` generally specifies the optimizer and plan settings: the ``method`` specifies the plan/policy representation (e.g. ``JaxStraightLinePlan``, ``JaxDeepReactivePolicy``), the SGD optimizer to use from optax, the learning rate, the batch size, etc.
+ - ``[Training]`` specifies how long training should proceed: ``epochs`` limits the total number of iterations, while ``train_seconds`` limits the total training time
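+
+ To give some intuition for the ``weight`` parameter, below is a minimal conceptual sketch (an illustration only, not the library's internal code) of how a Boolean comparison can be relaxed into a differentiable approximation whose sharpness is controlled by a weight:
+
+ ```python
+ import jax
+
+ # conceptual illustration only: a smooth stand-in for the Boolean comparison x > y;
+ # a larger weight brings the sigmoid closer to the true step function,
+ # at the cost of steeper (and potentially less stable) gradients
+ def relaxed_greater(x, y, weight):
+     return jax.nn.sigmoid(weight * (x - y))
+
+ print(relaxed_greater(1.0, 0.0, 20.0))   # close to 1 (true)
+ print(relaxed_greater(-1.0, 0.0, 20.0))  # close to 0 (false)
+ ```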
+
+ For a policy network approach, simply change the configuration file like so:
+
+ ```ini
+ ...
+ [Optimizer]
+ method='JaxDeepReactivePolicy'
+ method_kwargs={'topology': [128, 64]}
+ ...
+ ```
+
+ which would create a policy network with two hidden layers and ReLU activations.
+
+ The configuration file can then be passed to the planner during initialization. For example, the [previous script here](#running-from-the-python-api) can be modified to set parameters from a config file as follows:
+
+ ```python
+ from pyRDDLGym_jax.core.planner import load_config
+
+ # load the config file with planner settings
+ planner_args, _, train_args = load_config("/path/to/config.cfg")
+
+ # create the planning algorithm
+ planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
+ controller = JaxOfflineController(planner, **train_args)
+ ...
+ ```
+
+ ## Changing the pyRDDLGym Simulation Backend to JAX
+
+ The JAX compiler can be used as a backend for simulating and evaluating RDDL environments, instead of the usual pyRDDLGym one:
+
+ ```python
+ import pyRDDLGym
+ from pyRDDLGym.core.policy import RandomAgent
+ from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator
+
+ # create the environment
+ env = pyRDDLGym.make("domain", "instance", backend=JaxRDDLSimulator)
+
+ # evaluate the random policy
+ agent = RandomAgent(action_space=env.action_space,
+                     num_actions=env.max_allowed_actions)
+ agent.evaluate(env, verbose=True, render=True)
+ ```
+
+ For some domains, the JAX backend may perform better than the numpy-based one, due to various compiler optimizations.
+ In any event, the simulation results using the JAX backend should exactly match those of the numpy-based backend.
+
+ ## Computing the Model Gradients Manually
+
+ For custom applications, it is often desirable to compute gradients of the model for use in downstream optimization. Fortunately, we provide a convenient function for compiling the transition/step function ``P(s, a, s')`` of the environment into JAX.
+
+ ```python
+ import pyRDDLGym
+ from pyRDDLGym_jax.core.planner import JaxRDDLCompilerWithGrad
+
+ # set up the environment
+ env = pyRDDLGym.make("domain", "instance", vectorized=True)
+
+ # create the step function
+ compiled = JaxRDDLCompilerWithGrad(rddl=env.model)
+ compiled.compile()
+ step_fn = compiled.compile_transition()
+ ```
+
+ This will return a JAX-compiled (pure) function that requires 4 arguments:
+ - ``key`` is the ``jax.random.PRNGKey`` key for reproducible randomness
+ - ``actions`` is the dictionary of action-fluent JAX tensors
+ - ``subs`` is the dictionary of state-fluent and non-fluent JAX tensors
+ - ``model_params`` are the parameters of the differentiable relaxations, such as ``weight``
+
+ The function returns a dictionary containing a variety of variables, such as the updated pvariables including next-state fluents (``pvar``), the reward obtained (``reward``), and error codes (``error``).
+ It is thus possible to apply any JAX transformation to the output of the function, such as computing gradients using ``jax.grad()`` or batched simulation using ``jax.vmap()``.
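+ As a minimal sketch (assuming the argument order documented above, and placeholder ``actions``, ``subs``, and ``model_params`` dictionaries prepared for your own domain), the reward could be differentiated with respect to the actions like this:
+
+ ```python
+ import jax
+
+ # hypothetical sketch: reuses step_fn from the snippet above;
+ # actions, subs and model_params are placeholder dictionaries of JAX tensors
+ def reward_fn(actions, key, subs, model_params):
+     return step_fn(key, actions, subs, model_params)['reward']
+
+ key = jax.random.PRNGKey(42)
+ reward_grad_fn = jax.jit(jax.grad(reward_fn))
+ # gradients = reward_grad_fn(actions, key, subs, model_params)
+ ```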
+
+ Compilation of entire rollouts is possible by calling the ``compile_rollouts`` function and providing a policy implementation that maps states (JAX tensors) and tunable policy parameters to actions.
+ An [example is provided to illustrate how you can define your own policy class and compute the return gradient manually](https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/pyRDDLGym_jax/examples/run_gradient.py).
+
+ ## Citing pyRDDLGym-jax
+
+ The main ideas of this approach are discussed in the following preprint:
+
+ ```
+ @article{taitler2022pyrddlgym,
+     title={pyRDDLGym: From RDDL to Gym Environments},
+     author={Taitler, Ayal and Gimelfarb, Michael and Gopalakrishnan, Sriram and Mladenov, Martin and Liu, Xiaotian and Sanner, Scott},
+     journal={arXiv preprint arXiv:2211.05939},
+     year={2022}
+ }
+ ```
+
+ Many of the implementation details discussed come from the following literature, which you may wish to cite in your research papers:
+ - [A Distributional Framework for Risk-Sensitive End-to-End Planning in Continuous MDP, AAAI 2022](https://ojs.aaai.org/index.php/AAAI/article/view/21226)
+ - [Deep reactive policies for planning in stochastic nonlinear domains, AAAI 2019](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
+ - [Scalable planning with tensorflow for hybrid nonlinear domains, NeurIPS 2017](https://proceedings.neurips.cc/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)
+