pyRDDLGym-jax 1.0.tar.gz → 1.2.tar.gz

This diff compares the contents of two publicly released versions of this package as they appear in their public registries. It is provided for informational purposes only.
Files changed (55)
  1. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/PKG-INFO +40 -25
  2. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/README.md +22 -22
  3. pyrddlgym_jax-1.2/pyRDDLGym_jax/__init__.py +1 -0
  4. pyrddlgym_jax-1.2/pyRDDLGym_jax/core/assets/favicon.ico +0 -0
  5. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/core/compiler.py +60 -60
  6. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/core/planner.py +8 -10
  7. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/core/tuning.py +20 -6
  8. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/core/visualization.py +1 -1
  9. pyrddlgym_jax-1.2/pyRDDLGym_jax/entry_point.py +27 -0
  10. pyrddlgym_jax-1.2/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
  11. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/run_plan.py +20 -13
  12. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/run_tune.py +5 -3
  13. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax.egg-info/PKG-INFO +40 -25
  14. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax.egg-info/SOURCES.txt +4 -0
  15. pyrddlgym_jax-1.2/pyRDDLGym_jax.egg-info/entry_points.txt +2 -0
  16. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/setup.py +10 -3
  17. pyrddlgym_jax-1.0/pyRDDLGym_jax/__init__.py +0 -1
  18. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/LICENSE +0 -0
  19. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/core/__init__.py +0 -0
  20. {pyrddlgym_jax-1.0/pyRDDLGym_jax/examples → pyrddlgym_jax-1.2/pyRDDLGym_jax/core/assets}/__init__.py +0 -0
  21. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/core/logic.py +0 -0
  22. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/core/simulator.py +0 -0
  23. {pyrddlgym_jax-1.0/pyRDDLGym_jax/examples/configs → pyrddlgym_jax-1.2/pyRDDLGym_jax/examples}/__init__.py +0 -0
  24. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +0 -0
  25. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +0 -0
  26. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +0 -0
  27. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +0 -0
  28. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +0 -0
  29. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +0 -0
  30. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +0 -0
  31. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +0 -0
  32. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +0 -0
  33. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +0 -0
  34. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +0 -0
  35. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +0 -0
  36. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +0 -0
  37. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +0 -0
  38. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +0 -0
  39. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +0 -0
  40. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +0 -0
  41. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +0 -0
  42. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +0 -0
  43. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/default_drp.cfg +0 -0
  44. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/default_replan.cfg +0 -0
  45. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/default_slp.cfg +0 -0
  46. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/tuning_drp.cfg +0 -0
  47. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/tuning_replan.cfg +0 -0
  48. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/tuning_slp.cfg +0 -0
  49. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/run_gradient.py +0 -0
  50. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/run_gym.py +0 -0
  51. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/run_scipy.py +0 -0
  52. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax.egg-info/dependency_links.txt +0 -0
  53. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax.egg-info/requires.txt +0 -0
  54. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax.egg-info/top_level.txt +0 -0
  55. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/setup.cfg +0 -0
@@ -1,17 +1,21 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.2
  Name: pyRDDLGym-jax
- Version: 1.0
+ Version: 1.2
  Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
  Author-email: mike.gimelfarb@mail.utoronto.ca, ataitler@gmail.com, ssanner@mie.utoronto.ca
  License: MIT License
- Classifier: Development Status :: 3 - Alpha
+ Classifier: Development Status :: 5 - Production/Stable
  Classifier: Intended Audience :: Science/Research
  Classifier: License :: OSI Approved :: MIT License
  Classifier: Natural Language :: English
  Classifier: Operating System :: OS Independent
  Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Requires-Python: >=3.9
  Description-Content-Type: text/markdown
@@ -28,10 +32,31 @@ Requires-Dist: rddlrepository>=2.0; extra == "extra"
  Provides-Extra: dashboard
  Requires-Dist: dash>=2.18.0; extra == "dashboard"
  Requires-Dist: dash-bootstrap-components>=1.6.0; extra == "dashboard"
+ Dynamic: author
+ Dynamic: author-email
+ Dynamic: classifier
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: home-page
+ Dynamic: license
+ Dynamic: provides-extra
+ Dynamic: requires-dist
+ Dynamic: requires-python
+ Dynamic: summary

  # pyRDDLGym-jax

- **pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.** It provides:
+ ![Python Version](https://img.shields.io/badge/python-3.9%2B-blue)
+ [![PyPI Version](https://img.shields.io/pypi/v/pyRDDLGym-jax.svg)](https://pypi.org/project/pyRDDLGym-jax/)
+ [![Documentation Status](https://readthedocs.org/projects/pyrddlgym/badge/?version=latest)](https://pyrddlgym.readthedocs.io/en/latest/jax.html)
+ ![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)
+ [![Cumulative PyPI Downloads](https://img.shields.io/pypi/dm/pyrddlgym-jax)](https://pypistats.org/packages/pyrddlgym-jax)
+
+ [Installation](#installation) | [Run cmd](#running-from-the-command-line) | [Run python](#running-from-another-python-application) | [Configuration](#configuring-the-planner) | [Dashboard](#jaxplan-dashboard) | [Tuning](#tuning-the-planner) | [Simulation](#simulation) | [Citing](#citing-jaxplan)
+
+ **pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.**
+
+ Purpose:

  1. automatic translation of any RDDL description file into a differentiable simulator in JAX
  2. flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and Bayesian hyper-parameter tuning.
@@ -56,17 +81,6 @@ and was moved to the individual logic components which have their own unique wei
  > [!NOTE]
  > While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
  > If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
-
- ## Contents
-
- - [Installation](#installation)
- - [Running from the Command Line](#running-from-the-command-line)
- - [Running from Another Python Application](#running-from-another-python-application)
- - [Configuring the Planner](#configuring-the-planner)
- - [JaxPlan Dashboard](#jaxplan-dashboard)
- - [Tuning the Planner](#tuning-the-planner)
- - [Simulation](#simulation)
- - [Citing JaxPlan](#citing-jaxplan)

  ## Installation

@@ -96,27 +110,28 @@ pip install pyRDDLGym-jax[extra,dashboard]

  ## Running from the Command Line

- A basic run script is provided to run JaxPlan on any domain in ``rddlrepository`` from the install directory of pyRDDLGym-jax:
+ A basic run script is provided to train JaxPlan on any RDDL problem:

  ```shell
- python -m pyRDDLGym_jax.examples.run_plan <domain> <instance> <method> <episodes>
+ jaxplan plan <domain> <instance> <method> <episodes>
  ```

  where:
  - ``domain`` is the domain identifier as specified in rddlrepository (i.e. Wildfire_MDP_ippc2014), or a path pointing to a valid ``domain.rddl`` file
  - ``instance`` is the instance identifier (i.e. 1, 2, ... 10), or a path pointing to a valid ``instance.rddl`` file
- - ``method`` is the planning method to use (i.e. drp, slp, replan)
+ - ``method`` is the planning method to use (i.e. drp, slp, replan) or a path to a valid .cfg file (see section below)
  - ``episodes`` is the (optional) number of episodes to evaluate the learned policy.

- The ``method`` parameter supports three possible modes:
+ The ``method`` parameter supports four possible modes:
  - ``slp`` is the basic straight line planner described [in this paper](https://proceedings.neurips.cc/paper_files/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)
  - ``drp`` is the deep reactive policy network described [in this paper](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
- - ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step.
+ - ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step
+ - any other argument is interpreted as a file path to a valid configuration file.

- For example, the following will train JaxPlan on the Quadcopter domain with 4 drones:
+ For example, the following will train JaxPlan on the Quadcopter domain with 4 drones (with default config):

  ```shell
- python -m pyRDDLGym_jax.examples.run_plan Quadcopter 1 slp
+ jaxplan plan Quadcopter 1 slp
  ```

  ## Running from Another Python Application
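For orientation while reading the hunks below: the new `jaxplan plan` command is a thin layer over the Python API. A rough sketch of the equivalent calls, pieced together from the run_plan.py changes later in this diff (the `JaxBackpropPlanner` construction and the config path are assumptions based on the package README, not shown in this hunk):

```python
import pyRDDLGym
from pyRDDLGym_jax.core.planner import (
    JaxBackpropPlanner, JaxOfflineController, load_config)

# build the environment and load a planner config (paths assumed)
env = pyRDDLGym.make('Quadcopter', '1', vectorized=True)
planner_args, _, train_args = load_config('path/to/default_slp.cfg')

# train offline and evaluate, as run_plan.py does below
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)  # assumed signature
controller = JaxOfflineController(planner, **train_args)
controller.evaluate(env, episodes=1, verbose=True, render=True)
env.close()
```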
@@ -198,7 +213,7 @@ controller = JaxOfflineController(planner, **train_args)
  ...
  ```

- ### JaxPlan Dashboard
+ ## JaxPlan Dashboard

  Since version 1.0, JaxPlan has an optional dashboard that allows keeping track of the planner performance across multiple runs,
  and visualization of the policy or model, and other useful debugging features.
@@ -218,7 +233,7 @@ dashboard=True

  More documentation about this and other new features will be coming soon.

- ### Tuning the Planner
+ ## Tuning the Planner

  It is easy to tune the planner's hyper-parameters efficiently and automatically using Bayesian optimization.
  To do this, first create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
@@ -281,7 +296,7 @@ tuning.tune(key=42, log_file='path/to/log.csv')
  A basic run script is provided to run the automatic hyper-parameter tuning for the most sensitive parameters of JaxPlan:

  ```shell
- python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials> <iters> <workers>
+ jaxplan tune <domain> <instance> <method> <trials> <iters> <workers>
  ```

  where:
@@ -1,6 +1,16 @@
  # pyRDDLGym-jax

- **pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.** It provides:
+ ![Python Version](https://img.shields.io/badge/python-3.9%2B-blue)
+ [![PyPI Version](https://img.shields.io/pypi/v/pyRDDLGym-jax.svg)](https://pypi.org/project/pyRDDLGym-jax/)
+ [![Documentation Status](https://readthedocs.org/projects/pyrddlgym/badge/?version=latest)](https://pyrddlgym.readthedocs.io/en/latest/jax.html)
+ ![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)
+ [![Cumulative PyPI Downloads](https://img.shields.io/pypi/dm/pyrddlgym-jax)](https://pypistats.org/packages/pyrddlgym-jax)
+
+ [Installation](#installation) | [Run cmd](#running-from-the-command-line) | [Run python](#running-from-another-python-application) | [Configuration](#configuring-the-planner) | [Dashboard](#jaxplan-dashboard) | [Tuning](#tuning-the-planner) | [Simulation](#simulation) | [Citing](#citing-jaxplan)
+
+ **pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.**
+
+ Purpose:

  1. automatic translation of any RDDL description file into a differentiable simulator in JAX
  2. flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and Bayesian hyper-parameter tuning.
@@ -25,17 +35,6 @@ and was moved to the individual logic components which have their own unique wei
  > [!NOTE]
  > While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
  > If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
-
- ## Contents
-
- - [Installation](#installation)
- - [Running from the Command Line](#running-from-the-command-line)
- - [Running from Another Python Application](#running-from-another-python-application)
- - [Configuring the Planner](#configuring-the-planner)
- - [JaxPlan Dashboard](#jaxplan-dashboard)
- - [Tuning the Planner](#tuning-the-planner)
- - [Simulation](#simulation)
- - [Citing JaxPlan](#citing-jaxplan)

  ## Installation

@@ -65,27 +64,28 @@ pip install pyRDDLGym-jax[extra,dashboard]

  ## Running from the Command Line

- A basic run script is provided to run JaxPlan on any domain in ``rddlrepository`` from the install directory of pyRDDLGym-jax:
+ A basic run script is provided to train JaxPlan on any RDDL problem:

  ```shell
- python -m pyRDDLGym_jax.examples.run_plan <domain> <instance> <method> <episodes>
+ jaxplan plan <domain> <instance> <method> <episodes>
  ```

  where:
  - ``domain`` is the domain identifier as specified in rddlrepository (i.e. Wildfire_MDP_ippc2014), or a path pointing to a valid ``domain.rddl`` file
  - ``instance`` is the instance identifier (i.e. 1, 2, ... 10), or a path pointing to a valid ``instance.rddl`` file
- - ``method`` is the planning method to use (i.e. drp, slp, replan)
+ - ``method`` is the planning method to use (i.e. drp, slp, replan) or a path to a valid .cfg file (see section below)
  - ``episodes`` is the (optional) number of episodes to evaluate the learned policy.

- The ``method`` parameter supports three possible modes:
+ The ``method`` parameter supports four possible modes:
  - ``slp`` is the basic straight line planner described [in this paper](https://proceedings.neurips.cc/paper_files/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)
  - ``drp`` is the deep reactive policy network described [in this paper](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
- - ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step.
+ - ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step
+ - any other argument is interpreted as a file path to a valid configuration file.

- For example, the following will train JaxPlan on the Quadcopter domain with 4 drones:
+ For example, the following will train JaxPlan on the Quadcopter domain with 4 drones (with default config):

  ```shell
- python -m pyRDDLGym_jax.examples.run_plan Quadcopter 1 slp
+ jaxplan plan Quadcopter 1 slp
  ```

  ## Running from Another Python Application
@@ -167,7 +167,7 @@ controller = JaxOfflineController(planner, **train_args)
  ...
  ```

- ### JaxPlan Dashboard
+ ## JaxPlan Dashboard

  Since version 1.0, JaxPlan has an optional dashboard that allows keeping track of the planner performance across multiple runs,
  and visualization of the policy or model, and other useful debugging features.
@@ -187,7 +187,7 @@ dashboard=True

  More documentation about this and other new features will be coming soon.

- ### Tuning the Planner
+ ## Tuning the Planner

  It is easy to tune the planner's hyper-parameters efficiently and automatically using Bayesian optimization.
  To do this, first create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
@@ -250,7 +250,7 @@ tuning.tune(key=42, log_file='path/to/log.csv')
  A basic run script is provided to run the automatic hyper-parameter tuning for the most sensitive parameters of JaxPlan:

  ```shell
- python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials> <iters> <workers>
+ jaxplan tune <domain> <instance> <method> <trials> <iters> <workers>
  ```

  where:
@@ -0,0 +1 @@
+ __version__ = '1.2'
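The relocated top-level `__init__.py` now carries the package version, so it can be checked directly after installing (a minimal sketch):

```python
# quick sanity check after installing 1.2
import pyRDDLGym_jax
print(pyRDDLGym_jax.__version__)  # '1.2'
```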
@@ -51,65 +51,65 @@ class JaxRDDLCompiler:
              return func
          return exact_func

-     EXACT_RDDL_TO_JAX_NEGATIVE = wrap_logic(ExactLogic.exact_unary_function(jnp.negative))
+     EXACT_RDDL_TO_JAX_NEGATIVE = wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.negative))
      EXACT_RDDL_TO_JAX_ARITHMETIC = {
-         '+': wrap_logic(ExactLogic.exact_binary_function(jnp.add)),
-         '-': wrap_logic(ExactLogic.exact_binary_function(jnp.subtract)),
-         '*': wrap_logic(ExactLogic.exact_binary_function(jnp.multiply)),
-         '/': wrap_logic(ExactLogic.exact_binary_function(jnp.divide))
+         '+': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.add)),
+         '-': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.subtract)),
+         '*': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.multiply)),
+         '/': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.divide))
      }

      EXACT_RDDL_TO_JAX_RELATIONAL = {
-         '>=': wrap_logic(ExactLogic.exact_binary_function(jnp.greater_equal)),
-         '<=': wrap_logic(ExactLogic.exact_binary_function(jnp.less_equal)),
-         '<': wrap_logic(ExactLogic.exact_binary_function(jnp.less)),
-         '>': wrap_logic(ExactLogic.exact_binary_function(jnp.greater)),
-         '==': wrap_logic(ExactLogic.exact_binary_function(jnp.equal)),
-         '~=': wrap_logic(ExactLogic.exact_binary_function(jnp.not_equal))
+         '>=': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.greater_equal)),
+         '<=': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.less_equal)),
+         '<': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.less)),
+         '>': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.greater)),
+         '==': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.equal)),
+         '~=': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.not_equal))
      }

-     EXACT_RDDL_TO_JAX_LOGICAL_NOT = wrap_logic(ExactLogic.exact_unary_function(jnp.logical_not))
+     EXACT_RDDL_TO_JAX_LOGICAL_NOT = wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.logical_not))
      EXACT_RDDL_TO_JAX_LOGICAL = {
-         '^': wrap_logic(ExactLogic.exact_binary_function(jnp.logical_and)),
-         '&': wrap_logic(ExactLogic.exact_binary_function(jnp.logical_and)),
-         '|': wrap_logic(ExactLogic.exact_binary_function(jnp.logical_or)),
-         '~': wrap_logic(ExactLogic.exact_binary_function(jnp.logical_xor)),
-         '=>': wrap_logic(ExactLogic.exact_binary_implies),
-         '<=>': wrap_logic(ExactLogic.exact_binary_function(jnp.equal))
+         '^': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_and)),
+         '&': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_and)),
+         '|': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_or)),
+         '~': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_xor)),
+         '=>': wrap_logic.__func__(ExactLogic.exact_binary_implies),
+         '<=>': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.equal))
      }

      EXACT_RDDL_TO_JAX_AGGREGATION = {
-         'sum': wrap_logic(ExactLogic.exact_aggregation(jnp.sum)),
-         'avg': wrap_logic(ExactLogic.exact_aggregation(jnp.mean)),
-         'prod': wrap_logic(ExactLogic.exact_aggregation(jnp.prod)),
-         'minimum': wrap_logic(ExactLogic.exact_aggregation(jnp.min)),
-         'maximum': wrap_logic(ExactLogic.exact_aggregation(jnp.max)),
-         'forall': wrap_logic(ExactLogic.exact_aggregation(jnp.all)),
-         'exists': wrap_logic(ExactLogic.exact_aggregation(jnp.any)),
-         'argmin': wrap_logic(ExactLogic.exact_aggregation(jnp.argmin)),
-         'argmax': wrap_logic(ExactLogic.exact_aggregation(jnp.argmax))
+         'sum': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.sum)),
+         'avg': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.mean)),
+         'prod': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.prod)),
+         'minimum': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.min)),
+         'maximum': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.max)),
+         'forall': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.all)),
+         'exists': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.any)),
+         'argmin': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.argmin)),
+         'argmax': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.argmax))
      }

      EXACT_RDDL_TO_JAX_UNARY = {
-         'abs': wrap_logic(ExactLogic.exact_unary_function(jnp.abs)),
-         'sgn': wrap_logic(ExactLogic.exact_unary_function(jnp.sign)),
-         'round': wrap_logic(ExactLogic.exact_unary_function(jnp.round)),
-         'floor': wrap_logic(ExactLogic.exact_unary_function(jnp.floor)),
-         'ceil': wrap_logic(ExactLogic.exact_unary_function(jnp.ceil)),
-         'cos': wrap_logic(ExactLogic.exact_unary_function(jnp.cos)),
-         'sin': wrap_logic(ExactLogic.exact_unary_function(jnp.sin)),
-         'tan': wrap_logic(ExactLogic.exact_unary_function(jnp.tan)),
-         'acos': wrap_logic(ExactLogic.exact_unary_function(jnp.arccos)),
-         'asin': wrap_logic(ExactLogic.exact_unary_function(jnp.arcsin)),
-         'atan': wrap_logic(ExactLogic.exact_unary_function(jnp.arctan)),
-         'cosh': wrap_logic(ExactLogic.exact_unary_function(jnp.cosh)),
-         'sinh': wrap_logic(ExactLogic.exact_unary_function(jnp.sinh)),
-         'tanh': wrap_logic(ExactLogic.exact_unary_function(jnp.tanh)),
-         'exp': wrap_logic(ExactLogic.exact_unary_function(jnp.exp)),
-         'ln': wrap_logic(ExactLogic.exact_unary_function(jnp.log)),
-         'sqrt': wrap_logic(ExactLogic.exact_unary_function(jnp.sqrt)),
-         'lngamma': wrap_logic(ExactLogic.exact_unary_function(scipy.special.gammaln)),
-         'gamma': wrap_logic(ExactLogic.exact_unary_function(scipy.special.gamma))
+         'abs': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.abs)),
+         'sgn': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sign)),
+         'round': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.round)),
+         'floor': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.floor)),
+         'ceil': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.ceil)),
+         'cos': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.cos)),
+         'sin': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sin)),
+         'tan': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.tan)),
+         'acos': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.arccos)),
+         'asin': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.arcsin)),
+         'atan': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.arctan)),
+         'cosh': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.cosh)),
+         'sinh': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sinh)),
+         'tanh': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.tanh)),
+         'exp': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.exp)),
+         'ln': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.log)),
+         'sqrt': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sqrt)),
+         'lngamma': wrap_logic.__func__(ExactLogic.exact_unary_function(scipy.special.gammaln)),
+         'gamma': wrap_logic.__func__(ExactLogic.exact_unary_function(scipy.special.gamma))
      }

      @staticmethod
@@ -117,23 +117,23 @@ class JaxRDDLCompiler:
          return jnp.log(x) / jnp.log(y), params

      EXACT_RDDL_TO_JAX_BINARY = {
-         'div': wrap_logic(ExactLogic.exact_binary_function(jnp.floor_divide)),
-         'mod': wrap_logic(ExactLogic.exact_binary_function(jnp.mod)),
-         'fmod': wrap_logic(ExactLogic.exact_binary_function(jnp.mod)),
-         'min': wrap_logic(ExactLogic.exact_binary_function(jnp.minimum)),
-         'max': wrap_logic(ExactLogic.exact_binary_function(jnp.maximum)),
-         'pow': wrap_logic(ExactLogic.exact_binary_function(jnp.power)),
-         'log': wrap_logic(_jax_wrapped_calc_log_exact),
-         'hypot': wrap_logic(ExactLogic.exact_binary_function(jnp.hypot)),
+         'div': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.floor_divide)),
+         'mod': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.mod)),
+         'fmod': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.mod)),
+         'min': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.minimum)),
+         'max': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.maximum)),
+         'pow': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.power)),
+         'log': wrap_logic.__func__(_jax_wrapped_calc_log_exact.__func__),
+         'hypot': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.hypot)),
      }

-     EXACT_RDDL_TO_JAX_IF = wrap_logic(ExactLogic.exact_if_then_else)
-     EXACT_RDDL_TO_JAX_SWITCH = wrap_logic(ExactLogic.exact_switch)
+     EXACT_RDDL_TO_JAX_IF = wrap_logic.__func__(ExactLogic.exact_if_then_else)
+     EXACT_RDDL_TO_JAX_SWITCH = wrap_logic.__func__(ExactLogic.exact_switch)

-     EXACT_RDDL_TO_JAX_BERNOULLI = wrap_logic(ExactLogic.exact_bernoulli)
-     EXACT_RDDL_TO_JAX_DISCRETE = wrap_logic(ExactLogic.exact_discrete)
-     EXACT_RDDL_TO_JAX_POISSON = wrap_logic(ExactLogic.exact_poisson)
-     EXACT_RDDL_TO_JAX_GEOMETRIC = wrap_logic(ExactLogic.exact_geometric)
+     EXACT_RDDL_TO_JAX_BERNOULLI = wrap_logic.__func__(ExactLogic.exact_bernoulli)
+     EXACT_RDDL_TO_JAX_DISCRETE = wrap_logic.__func__(ExactLogic.exact_discrete)
+     EXACT_RDDL_TO_JAX_POISSON = wrap_logic.__func__(ExactLogic.exact_poisson)
+     EXACT_RDDL_TO_JAX_GEOMETRIC = wrap_logic.__func__(ExactLogic.exact_geometric)

      def __init__(self, rddl: RDDLLiftedModel,
                   allow_synchronous_state: bool=True,
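The blanket `wrap_logic(...)` → `wrap_logic.__func__(...)` change above looks mechanical, but it plausibly restores Python 3.9 compatibility (note the new 3.9 classifier in this release): a staticmethod object only became directly callable inside a class body in Python 3.10. A standalone illustration of the failure mode and the fix (toy class, not project code):

```python
class Table:
    @staticmethod
    def wrap(f):
        return f

    # On Python 3.9, writing `wrap(abs)` here raises TypeError, because
    # `wrap` is still a raw staticmethod object during class creation.
    # `.__func__` retrieves the underlying plain function, which is
    # callable on 3.9 and later alike.
    ABS = wrap.__func__(abs)

assert Table.ABS(-3) == 3
```

The same reasoning would explain `_jax_wrapped_calc_log_exact.__func__`: that helper is itself a staticmethod, unwrapped before being passed as a plain function value.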
@@ -65,9 +65,8 @@ def _parse_config_file(path: str):
      config = configparser.RawConfigParser()
      config.optionxform = str
      config.read(path)
-     args = {k: literal_eval(v)
-             for section in config.sections()
-             for (k, v) in config.items(section)}
+     args = {section: {k: literal_eval(v) for (k, v) in config.items(section)}
+             for section in config.sections()}
      return config, args


@@ -75,9 +74,8 @@ def _parse_config_string(value: str):
      config = configparser.RawConfigParser()
      config.optionxform = str
      config.read_string(value)
-     args = {k: literal_eval(v)
-             for section in config.sections()
-             for (k, v) in config.items(section)}
+     args = {section: {k: literal_eval(v) for (k, v) in config.items(section)}
+             for section in config.sections()}
      return config, args


@@ -90,9 +88,9 @@ def _getattr_any(packages, item):


  def _load_config(config, args):
-     model_args = {k: args[k] for (k, _) in config.items('Model')}
-     planner_args = {k: args[k] for (k, _) in config.items('Optimizer')}
-     train_args = {k: args[k] for (k, _) in config.items('Training')}
+     model_args = {k: args['Model'][k] for (k, _) in config.items('Model')}
+     planner_args = {k: args['Optimizer'][k] for (k, _) in config.items('Optimizer')}
+     train_args = {k: args['Training'][k] for (k, _) in config.items('Training')}

      # read the model settings
      logic_name = model_args.get('logic', 'FuzzyLogic')
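The reworked parser keys the result by section, so options with the same name in different sections no longer clobber one another, and `_load_config` indexes into `args['Model']`, `args['Optimizer']`, and `args['Training']` accordingly. A self-contained sketch of the new shape (mirroring `_parse_config_string` above; the option names are illustrative):

```python
import configparser
from ast import literal_eval

config = configparser.RawConfigParser()
config.optionxform = str  # preserve option-name case
config.read_string("""
[Model]
logic='FuzzyLogic'

[Training]
epochs=1000
""")

# one nested dict per section, exactly as in the new planner.py
args = {section: {k: literal_eval(v) for (k, v) in config.items(section)}
        for section in config.sections()}
assert args['Model']['logic'] == 'FuzzyLogic'
assert args['Training']['epochs'] == 1000
```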
@@ -1661,7 +1659,7 @@ r"""
      def optimize_generator(self, key: Optional[random.PRNGKey]=None,
                             epochs: int=999999,
                             train_seconds: float=120.,
-                            dashboard: Optional[JaxPlannerDashboard]=None,
+                            dashboard: Optional[Any]=None,
                             dashboard_id: Optional[str]=None,
                             model_params: Optional[Dict[str, Any]]=None,
                             policy_hyperparams: Optional[Dict[str, Any]]=None,
@@ -4,6 +4,7 @@ import threading
  import multiprocessing
  import os
  import time
+ import traceback
  from typing import Any, Callable, Dict, Iterable, Optional, Tuple
  import warnings
  warnings.filterwarnings("ignore")
@@ -14,6 +15,7 @@ from bayes_opt.acquisition import AcquisitionFunction, UpperConfidenceBound
  import jax
  import numpy as np

+ from pyRDDLGym.core.debug.exception import raise_warning
  from pyRDDLGym.core.env import RDDLEnv

  from pyRDDLGym_jax.core.planner import (
@@ -64,6 +66,7 @@ class JaxParameterTuning:
               hyperparams: Hyperparameters,
               online: bool,
               eval_trials: int=5,
+              rollouts_per_trial: int=1,
               verbose: bool=True,
               timeout_tuning: float=np.inf,
               pool_context: str='spawn',
@@ -87,6 +90,8 @@ class JaxParameterTuning:
      hyperparameters in general (in seconds)
      :param eval_trials: how many trials to perform independent training
      in order to estimate the return for each set of hyper-parameters
+     :param rollouts_per_trial: how many rollouts to perform during evaluation
+     at the end of each training trial (only applies when online=False)
      :param verbose: whether to print intermediate results of tuning
      :param pool_context: context for multiprocessing pool (default "spawn")
      :param num_workers: how many points to evaluate in parallel
@@ -108,6 +113,7 @@ class JaxParameterTuning:
          self.hyperparams_dict = hyperparams_dict
          self.online = online
          self.eval_trials = eval_trials
+         self.rollouts_per_trial = rollouts_per_trial
          self.verbose = verbose

          # Bayesian parameters
@@ -154,6 +160,7 @@ class JaxParameterTuning:
      f' mp_pool_poll_frequency ={self.poll_frequency}\n'
      f'meta-objective parameters:\n'
      f' planning_trials_per_iter ={self.eval_trials}\n'
+     f' rollouts_per_trial ={self.rollouts_per_trial}\n'
      f' acquisition_fn ={self.acquisition}')

      @staticmethod
@@ -200,12 +207,14 @@ class JaxParameterTuning:

      @staticmethod
      def offline_trials(env, planner, train_args, key, iteration, index, num_trials,
-                        verbose, viz, queue):
+                        rollouts_per_trial, verbose, viz, queue):
          average_reward = 0.0
          for trial in range(num_trials):
              key, subkey = jax.random.split(key)
+
+             # for the dashboard
              experiment_id = f'iter={iteration}, worker={index}, trial={trial}'
-             if queue is not None:
+             if queue is not None and JaxPlannerDashboard is not None:
                  queue.put((
                      experiment_id,
                      JaxPlannerDashboard.get_planner_info(planner),
@@ -224,7 +233,8 @@ class JaxParameterTuning:
              policy = JaxOfflineController(
                  planner=planner, key=subkey, tqdm_position=index,
                  params=best_params, train_on_reset=False)
-             total_reward = policy.evaluate(env, seed=np.array(subkey)[0])['mean']
+             total_reward = policy.evaluate(env, episodes=rollouts_per_trial,
+                                            seed=np.array(subkey)[0])['mean']

              # update average reward
              if verbose:
@@ -243,8 +253,10 @@ class JaxParameterTuning:
          average_reward = 0.0
          for trial in range(num_trials):
              key, subkey = jax.random.split(key)
+
+             # for the dashboard
              experiment_id = f'iter={iteration}, worker={index}, trial={trial}'
-             if queue is not None:
+             if queue is not None and JaxPlannerDashboard is not None:
                  queue.put((
                      experiment_id,
                      JaxPlannerDashboard.get_planner_info(planner),
@@ -304,6 +316,7 @@ class JaxParameterTuning:
          domain = kwargs['domain']
          instance = kwargs['instance']
          num_trials = kwargs['eval_trials']
+         rollouts_per_trial = kwargs['rollouts_per_trial']
          viz = kwargs['viz']
          verbose = kwargs['verbose']

@@ -332,7 +345,7 @@ class JaxParameterTuning:
          else:
              average_reward = JaxParameterTuning.offline_trials(
                  env, planner, train_args, key, iteration, index,
-                 num_trials, verbose, viz, queue
+                 num_trials, rollouts_per_trial, verbose, viz, queue
              )

          pid = os.getpid()
@@ -353,7 +366,7 @@ class JaxParameterTuning:
          writer.writerow(COLUMNS + list(self.hyperparams_dict.keys()))

          # create a dash-board for visualizing experiment runs
-         if show_dashboard:
+         if show_dashboard and JaxPlannerDashboard is not None:
              dashboard = JaxPlannerDashboard()
              dashboard.launch()

@@ -365,6 +378,7 @@ class JaxParameterTuning:
              'domain': self.env.domain_text,
              'instance': self.env.instance_text,
              'eval_trials': self.eval_trials,
+             'rollouts_per_trial': self.rollouts_per_trial,
              'viz': self.env._visualizer,
              'verbose': self.verbose
          }
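Putting the tuning changes together: a hedged construction sketch using only the parameters visible in this diff (the import path, the `env` keyword, and the example values are assumptions; `tune(key=..., log_file=...)` matches the README above):

```python
from pyRDDLGym_jax.core.tuning import JaxParameterTuning  # path assumed


def tune_offline(env, hyperparams):
    """Sketch: env is a pyRDDLGym RDDLEnv, hyperparams a Hyperparameters space."""
    tuning = JaxParameterTuning(
        env=env,                  # keyword name assumed
        hyperparams=hyperparams,
        online=False,             # offline: train per trial, then evaluate
        eval_trials=5,            # independent training runs per candidate
        rollouts_per_trial=10,    # new in 1.2: evaluation episodes per run
        num_workers=4)            # candidates evaluated in parallel
    tuning.tune(key=42, log_file='path/to/log.csv')
```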
@@ -1405,7 +1405,7 @@ class JaxPlannerDashboard:
          self.test_reward_dist[experiment_id] = callback['reward']
          self.train_state_fluents[experiment_id] = {
              name: np.asarray(callback['train_log']['fluents'][name])
-             for name in rddl.state_fluents or name in rddl.observ_fluents
+             for name in rddl.state_fluents
          }
          self.test_state_fluents[experiment_id] = {
              name: np.asarray(callback['fluents'][name])
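The one-line visualization.py fix above removes an operator-precedence bug: in `for name in A or name in B`, the `or` binds inside the iterable expression, so `B` is never iterated. A standalone illustration (toy data, not project code):

```python
state_fluents = ['rain', 'temp']
observ_fluents = ['obs-temp']

# Buggy form: the iterable is `state_fluents or (name in observ_fluents)`,
# which short-circuits to state_fluents whenever it is non-empty, so
# observ_fluents is silently ignored.
old = [name for name in state_fluents or name in observ_fluents]
assert old == ['rain', 'temp']   # 'obs-temp' never appears

# The 1.2 fix simply drops the dead clause and iterates state_fluents.
```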
@@ -0,0 +1,27 @@
+ import argparse
+
+ from pyRDDLGym_jax.examples import run_plan, run_tune
+
+ def main():
+     parser = argparse.ArgumentParser(description="Command line parser for the JaxPlan planner.")
+     subparsers = parser.add_subparsers(dest="jaxplan", required=True)
+
+     # planning
+     parser_plan = subparsers.add_parser("plan", help="Executes JaxPlan on a specified RDDL problem and method (slp, drp, or replan).")
+     parser_plan.add_argument('args', nargs=argparse.REMAINDER)
+
+     # tuning
+     parser_tune = subparsers.add_parser("tune", help="Tunes JaxPlan on a specified RDDL problem and method (slp, drp, or replan).")
+     parser_tune.add_argument('args', nargs=argparse.REMAINDER)
+
+     # dispatch
+     args = parser.parse_args()
+     if args.jaxplan == "plan":
+         run_plan.run_from_args(args.args)
+     elif args.jaxplan == "tune":
+         run_tune.run_from_args(args.args)
+     else:
+         parser.print_help()
+
+ if __name__ == "__main__":
+     main()
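The console script is a thin dispatcher, so the same entry points remain usable programmatically (argument order mirrors the CLI; the trailing episode count is optional):

```python
# equivalent to: jaxplan plan Quadcopter 1 slp 1
from pyRDDLGym_jax.examples import run_plan

run_plan.run_from_args(['Quadcopter', '1', 'slp', '1'])
```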
@@ -12,7 +12,7 @@ The syntax for running this example is:
  where:
      <domain> is the name of a domain located in the /Examples directory
      <instance> is the instance number
-     <method> is either slp, drp, or replan
+     <method> is slp, drp, replan, or a path to a valid .cfg file
      <episodes> is the optional number of evaluation rollouts
  '''
  import os
@@ -32,12 +32,19 @@ def main(domain, instance, method, episodes=1):
      env = pyRDDLGym.make(domain, instance, vectorized=True)

      # load the config file with planner settings
-     abs_path = os.path.dirname(os.path.abspath(__file__))
-     config_path = os.path.join(abs_path, 'configs', f'{domain}_{method}.cfg')
-     if not os.path.isfile(config_path):
-         raise_warning(f'Config file {config_path} was not found, '
-                       f'using default_{method}.cfg.', 'red')
-         config_path = os.path.join(abs_path, 'configs', f'default_{method}.cfg')
+     if method in ['drp', 'slp', 'replan']:
+         abs_path = os.path.dirname(os.path.abspath(__file__))
+         config_path = os.path.join(abs_path, 'configs', f'{domain}_{method}.cfg')
+         if not os.path.isfile(config_path):
+             raise_warning(f'Config file {config_path} was not found, '
+                           f'using default_{method}.cfg.', 'red')
+             config_path = os.path.join(abs_path, 'configs', f'default_{method}.cfg')
+     elif os.path.isfile(method):
+         config_path = method
+     else:
+         print('method must be slp, drp, replan, or a path to a valid .cfg file.')
+         exit(1)
+
      planner_args, _, train_args = load_config(config_path)
      if 'dashboard' in train_args:
          train_args['dashboard'].launch()
@@ -54,16 +61,16 @@ def main(domain, instance, method, episodes=1):
      controller.evaluate(env, episodes=episodes, verbose=True, render=True)
      env.close()

-
- if __name__ == "__main__":
-     args = sys.argv[1:]
+
+ def run_from_args(args):
      if len(args) < 3:
          print('python run_plan.py <domain> <instance> <method> [<episodes>]')
          exit(1)
-     if args[2] not in ['drp', 'slp', 'replan']:
-         print('<method> in [drp, slp, replan]')
-         exit(1)
      kwargs = {'domain': args[0], 'instance': args[1], 'method': args[2]}
      if len(args) >= 4: kwargs['episodes'] = int(args[3])
      main(**kwargs)
+
+
+ if __name__ == "__main__":
+     run_from_args(sys.argv[1:])

@@ -75,8 +75,7 @@ def main(domain, instance, method, trials=5, iters=20, workers=4):
      env.close()


- if __name__ == "__main__":
-     args = sys.argv[1:]
+ def run_from_args(args):
      if len(args) < 3:
          print('python run_tune.py <domain> <instance> <method> [<trials>] [<iters>] [<workers>]')
          exit(1)
@@ -88,4 +87,7 @@ if __name__ == "__main__":
      if len(args) >= 5: kwargs['iters'] = int(args[4])
      if len(args) >= 6: kwargs['workers'] = int(args[5])
      main(**kwargs)
-
+
+
+ if __name__ == "__main__":
+     run_from_args(sys.argv[1:])
@@ -1,17 +1,21 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.2
  Name: pyRDDLGym-jax
- Version: 1.0
+ Version: 1.2
  Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
  Author-email: mike.gimelfarb@mail.utoronto.ca, ataitler@gmail.com, ssanner@mie.utoronto.ca
  License: MIT License
- Classifier: Development Status :: 3 - Alpha
+ Classifier: Development Status :: 5 - Production/Stable
  Classifier: Intended Audience :: Science/Research
  Classifier: License :: OSI Approved :: MIT License
  Classifier: Natural Language :: English
  Classifier: Operating System :: OS Independent
  Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Requires-Python: >=3.9
  Description-Content-Type: text/markdown
@@ -28,10 +32,31 @@ Requires-Dist: rddlrepository>=2.0; extra == "extra"
  Provides-Extra: dashboard
  Requires-Dist: dash>=2.18.0; extra == "dashboard"
  Requires-Dist: dash-bootstrap-components>=1.6.0; extra == "dashboard"
+ Dynamic: author
+ Dynamic: author-email
+ Dynamic: classifier
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: home-page
+ Dynamic: license
+ Dynamic: provides-extra
+ Dynamic: requires-dist
+ Dynamic: requires-python
+ Dynamic: summary

  # pyRDDLGym-jax

- **pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.** It provides:
+ ![Python Version](https://img.shields.io/badge/python-3.9%2B-blue)
+ [![PyPI Version](https://img.shields.io/pypi/v/pyRDDLGym-jax.svg)](https://pypi.org/project/pyRDDLGym-jax/)
+ [![Documentation Status](https://readthedocs.org/projects/pyrddlgym/badge/?version=latest)](https://pyrddlgym.readthedocs.io/en/latest/jax.html)
+ ![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)
+ [![Cumulative PyPI Downloads](https://img.shields.io/pypi/dm/pyrddlgym-jax)](https://pypistats.org/packages/pyrddlgym-jax)
+
+ [Installation](#installation) | [Run cmd](#running-from-the-command-line) | [Run python](#running-from-another-python-application) | [Configuration](#configuring-the-planner) | [Dashboard](#jaxplan-dashboard) | [Tuning](#tuning-the-planner) | [Simulation](#simulation) | [Citing](#citing-jaxplan)
+
+ **pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.**
+
+ Purpose:

  1. automatic translation of any RDDL description file into a differentiable simulator in JAX
  2. flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and Bayesian hyper-parameter tuning.
@@ -56,17 +81,6 @@ and was moved to the individual logic components which have their own unique wei
  > [!NOTE]
  > While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
  > If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
-
- ## Contents
-
- - [Installation](#installation)
- - [Running from the Command Line](#running-from-the-command-line)
- - [Running from Another Python Application](#running-from-another-python-application)
- - [Configuring the Planner](#configuring-the-planner)
- - [JaxPlan Dashboard](#jaxplan-dashboard)
- - [Tuning the Planner](#tuning-the-planner)
- - [Simulation](#simulation)
- - [Citing JaxPlan](#citing-jaxplan)

  ## Installation

@@ -96,27 +110,28 @@ pip install pyRDDLGym-jax[extra,dashboard]

  ## Running from the Command Line

- A basic run script is provided to run JaxPlan on any domain in ``rddlrepository`` from the install directory of pyRDDLGym-jax:
+ A basic run script is provided to train JaxPlan on any RDDL problem:

  ```shell
- python -m pyRDDLGym_jax.examples.run_plan <domain> <instance> <method> <episodes>
+ jaxplan plan <domain> <instance> <method> <episodes>
  ```

  where:
  - ``domain`` is the domain identifier as specified in rddlrepository (i.e. Wildfire_MDP_ippc2014), or a path pointing to a valid ``domain.rddl`` file
  - ``instance`` is the instance identifier (i.e. 1, 2, ... 10), or a path pointing to a valid ``instance.rddl`` file
- - ``method`` is the planning method to use (i.e. drp, slp, replan)
+ - ``method`` is the planning method to use (i.e. drp, slp, replan) or a path to a valid .cfg file (see section below)
  - ``episodes`` is the (optional) number of episodes to evaluate the learned policy.

- The ``method`` parameter supports three possible modes:
+ The ``method`` parameter supports four possible modes:
  - ``slp`` is the basic straight line planner described [in this paper](https://proceedings.neurips.cc/paper_files/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)
  - ``drp`` is the deep reactive policy network described [in this paper](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
- - ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step.
+ - ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step
+ - any other argument is interpreted as a file path to a valid configuration file.

- For example, the following will train JaxPlan on the Quadcopter domain with 4 drones:
+ For example, the following will train JaxPlan on the Quadcopter domain with 4 drones (with default config):

  ```shell
- python -m pyRDDLGym_jax.examples.run_plan Quadcopter 1 slp
+ jaxplan plan Quadcopter 1 slp
  ```

  ## Running from Another Python Application
@@ -198,7 +213,7 @@ controller = JaxOfflineController(planner, **train_args)
  ...
  ```

- ### JaxPlan Dashboard
+ ## JaxPlan Dashboard

  Since version 1.0, JaxPlan has an optional dashboard that allows keeping track of the planner performance across multiple runs,
  and visualization of the policy or model, and other useful debugging features.
@@ -218,7 +233,7 @@ dashboard=True

  More documentation about this and other new features will be coming soon.

- ### Tuning the Planner
+ ## Tuning the Planner

  It is easy to tune the planner's hyper-parameters efficiently and automatically using Bayesian optimization.
  To do this, first create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
@@ -281,7 +296,7 @@ tuning.tune(key=42, log_file='path/to/log.csv')
  A basic run script is provided to run the automatic hyper-parameter tuning for the most sensitive parameters of JaxPlan:

  ```shell
- python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials> <iters> <workers>
+ jaxplan tune <domain> <instance> <method> <trials> <iters> <workers>
  ```

  where:
@@ -2,9 +2,11 @@ LICENSE
  README.md
  setup.py
  pyRDDLGym_jax/__init__.py
+ pyRDDLGym_jax/entry_point.py
  pyRDDLGym_jax.egg-info/PKG-INFO
  pyRDDLGym_jax.egg-info/SOURCES.txt
  pyRDDLGym_jax.egg-info/dependency_links.txt
+ pyRDDLGym_jax.egg-info/entry_points.txt
  pyRDDLGym_jax.egg-info/requires.txt
  pyRDDLGym_jax.egg-info/top_level.txt
  pyRDDLGym_jax/core/__init__.py
@@ -14,6 +16,8 @@ pyRDDLGym_jax/core/planner.py
  pyRDDLGym_jax/core/simulator.py
  pyRDDLGym_jax/core/tuning.py
  pyRDDLGym_jax/core/visualization.py
+ pyRDDLGym_jax/core/assets/__init__.py
+ pyRDDLGym_jax/core/assets/favicon.ico
  pyRDDLGym_jax/examples/__init__.py
  pyRDDLGym_jax/examples/run_gradient.py
  pyRDDLGym_jax/examples/run_gym.py
@@ -0,0 +1,2 @@
+ [console_scripts]
+ jaxplan = pyRDDLGym_jax.entry_point:main
@@ -19,7 +19,7 @@ long_description = (Path(__file__).parent / "README.md").read_text()

  setup(
      name='pyRDDLGym-jax',
-     version='1.0',
+     version='1.2',
      author="Michael Gimelfarb, Ayal Taitler, Scott Sanner",
      author_email="mike.gimelfarb@mail.utoronto.ca, ataitler@gmail.com, ssanner@mie.utoronto.ca",
      description="pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.",
@@ -41,15 +41,22 @@ setup(
          'dashboard': ['dash>=2.18.0', 'dash-bootstrap-components>=1.6.0']
      },
      python_requires=">=3.9",
-     package_data={'': ['*.cfg']},
+     package_data={'': ['*.cfg', '*.ico']},
      include_package_data=True,
+     entry_points={
+         'console_scripts': ['jaxplan=pyRDDLGym_jax.entry_point:main'],
+     },
      classifiers=[
-         "Development Status :: 3 - Alpha",
+         "Development Status :: 5 - Production/Stable",
          "Intended Audience :: Science/Research",
          "License :: OSI Approved :: MIT License",
          "Natural Language :: English",
          "Operating System :: OS Independent",
          "Programming Language :: Python :: 3",
+         "Programming Language :: Python :: 3.9",
+         "Programming Language :: Python :: 3.10",
+         "Programming Language :: Python :: 3.11",
+         "Programming Language :: Python :: 3.12",
          "Topic :: Scientific/Engineering :: Artificial Intelligence",
      ],
  )
@@ -1 +0,0 @@
- __version__ = '1.0'