pyRDDLGym-jax 1.0.tar.gz → 1.2.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/PKG-INFO +40 -25
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/README.md +22 -22
- pyrddlgym_jax-1.2/pyRDDLGym_jax/__init__.py +1 -0
- pyrddlgym_jax-1.2/pyRDDLGym_jax/core/assets/favicon.ico +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/core/compiler.py +60 -60
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/core/planner.py +8 -10
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/core/tuning.py +20 -6
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/core/visualization.py +1 -1
- pyrddlgym_jax-1.2/pyRDDLGym_jax/entry_point.py +27 -0
- pyrddlgym_jax-1.2/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/run_plan.py +20 -13
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/run_tune.py +5 -3
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax.egg-info/PKG-INFO +40 -25
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax.egg-info/SOURCES.txt +4 -0
- pyrddlgym_jax-1.2/pyRDDLGym_jax.egg-info/entry_points.txt +2 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/setup.py +10 -3
- pyrddlgym_jax-1.0/pyRDDLGym_jax/__init__.py +0 -1
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/LICENSE +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/core/__init__.py +0 -0
- {pyrddlgym_jax-1.0/pyRDDLGym_jax/examples → pyrddlgym_jax-1.2/pyRDDLGym_jax/core/assets}/__init__.py +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/core/logic.py +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/core/simulator.py +0 -0
- {pyrddlgym_jax-1.0/pyRDDLGym_jax/examples/configs → pyrddlgym_jax-1.2/pyRDDLGym_jax/examples}/__init__.py +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/default_drp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/default_replan.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/default_slp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/tuning_drp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/tuning_replan.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/configs/tuning_slp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/run_gradient.py +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/run_gym.py +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/run_scipy.py +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax.egg-info/dependency_links.txt +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax.egg-info/requires.txt +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax.egg-info/top_level.txt +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/setup.cfg +0 -0
{pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/PKG-INFO

@@ -1,17 +1,21 @@
-Metadata-Version: 2.
+Metadata-Version: 2.2
 Name: pyRDDLGym-jax
-Version: 1.0
+Version: 1.2
 Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
 Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
 Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
 Author-email: mike.gimelfarb@mail.utoronto.ca, ataitler@gmail.com, ssanner@mie.utoronto.ca
 License: MIT License
-Classifier: Development Status ::
+Classifier: Development Status :: 5 - Production/Stable
 Classifier: Intended Audience :: Science/Research
 Classifier: License :: OSI Approved :: MIT License
 Classifier: Natural Language :: English
 Classifier: Operating System :: OS Independent
 Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Description-Content-Type: text/markdown

@@ -28,10 +32,31 @@ Requires-Dist: rddlrepository>=2.0; extra == "extra"
 Provides-Extra: dashboard
 Requires-Dist: dash>=2.18.0; extra == "dashboard"
 Requires-Dist: dash-bootstrap-components>=1.6.0; extra == "dashboard"
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: license
+Dynamic: provides-extra
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary
 
 # pyRDDLGym-jax
 
-
+[PyPI](https://pypi.org/project/pyRDDLGym-jax/)
+[Documentation](https://pyrddlgym.readthedocs.io/en/latest/jax.html)
+[Downloads](https://pypistats.org/packages/pyrddlgym-jax)
+
+[Installation](#installation) | [Run cmd](#running-from-the-command-line) | [Run python](#running-from-another-python-application) | [Configuration](#configuring-the-planner) | [Dashboard](#jaxplan-dashboard) | [Tuning](#tuning-the-planner) | [Simulation](#simulation) | [Citing](#citing-jaxplan)
+
+**pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.**
+
+Purpose:
 
 1. automatic translation of any RDDL description file into a differentiable simulator in JAX
 2. flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and Bayesian hyper-parameter tuning.

@@ -56,17 +81,6 @@ and was moved to the individual logic components which have their own unique wei
 > [!NOTE]
 > While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
 > If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
-
-## Contents
-
-- [Installation](#installation)
-- [Running from the Command Line](#running-from-the-command-line)
-- [Running from Another Python Application](#running-from-another-python-application)
-- [Configuring the Planner](#configuring-the-planner)
-- [JaxPlan Dashboard](#jaxplan-dashboard)
-- [Tuning the Planner](#tuning-the-planner)
-- [Simulation](#simulation)
-- [Citing JaxPlan](#citing-jaxplan)
 
 ## Installation

@@ -96,27 +110,28 @@ pip install pyRDDLGym-jax[extra,dashboard]
 
 ## Running from the Command Line
 
-A basic run script is provided to
+A basic run script is provided to train JaxPlan on any RDDL problem:
 
 ```shell
-
+jaxplan plan <domain> <instance> <method> <episodes>
 ```
 
 where:
 - ``domain`` is the domain identifier as specified in rddlrepository (i.e. Wildfire_MDP_ippc2014), or a path pointing to a valid ``domain.rddl`` file
 - ``instance`` is the instance identifier (i.e. 1, 2, ... 10), or a path pointing to a valid ``instance.rddl`` file
-- ``method`` is the planning method to use (i.e. drp, slp, replan)
+- ``method`` is the planning method to use (i.e. drp, slp, replan) or a path to a valid .cfg file (see section below)
 - ``episodes`` is the (optional) number of episodes to evaluate the learned policy.
 
-The ``method`` parameter supports
+The ``method`` parameter supports four possible modes:
 - ``slp`` is the basic straight line planner described [in this paper](https://proceedings.neurips.cc/paper_files/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)
 - ``drp`` is the deep reactive policy network described [in this paper](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
-- ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step
+- ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step
+- any other argument is interpreted as a file path to a valid configuration file.
 
-For example, the following will train JaxPlan on the Quadcopter domain with 4 drones:
+For example, the following will train JaxPlan on the Quadcopter domain with 4 drones (with default config):
 
 ```shell
-
+jaxplan plan Quadcopter 1 slp
 ```

@@ -198,7 +213,7 @@ controller = JaxOfflineController(planner, **train_args)
 ...
 ```
 
-
+## JaxPlan Dashboard
 
 Since version 1.0, JaxPlan has an optional dashboard that allows keeping track of the planner performance across multiple runs,
 and visualization of the policy or model, and other useful debugging features.

@@ -218,7 +233,7 @@ dashboard=True
 
 More documentation about this and other new features will be coming soon.
 
-
+## Tuning the Planner
 
 It is easy to tune the planner's hyper-parameters efficiently and automatically using Bayesian optimization.
 To do this, first create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:

@@ -281,7 +296,7 @@ tuning.tune(key=42, log_file='path/to/log.csv')
 A basic run script is provided to run the automatic hyper-parameter tuning for the most sensitive parameters of JaxPlan:
 
 ```shell
-
+jaxplan tune <domain> <instance> <method> <trials> <iters> <workers>
 ```
 
 where:
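The bump to Metadata-Version 2.2 also explains the new `Dynamic:` block: PEP 643 (which defined core metadata 2.2) lets a source distribution mark fields whose values are computed at build time from `setup.py`, and recent setuptools releases emit one `Dynamic:` entry per such field.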
{pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/README.md

@@ -1,6 +1,16 @@
 # pyRDDLGym-jax
 
-
+[PyPI](https://pypi.org/project/pyRDDLGym-jax/)
+[Documentation](https://pyrddlgym.readthedocs.io/en/latest/jax.html)
+[Downloads](https://pypistats.org/packages/pyrddlgym-jax)
+
+[Installation](#installation) | [Run cmd](#running-from-the-command-line) | [Run python](#running-from-another-python-application) | [Configuration](#configuring-the-planner) | [Dashboard](#jaxplan-dashboard) | [Tuning](#tuning-the-planner) | [Simulation](#simulation) | [Citing](#citing-jaxplan)
+
+**pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.**
+
+Purpose:
 
 1. automatic translation of any RDDL description file into a differentiable simulator in JAX
 2. flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and Bayesian hyper-parameter tuning.

@@ -25,17 +35,6 @@ and was moved to the individual logic components which have their own unique wei
 > [!NOTE]
 > While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
 > If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
-
-## Contents
-
-- [Installation](#installation)
-- [Running from the Command Line](#running-from-the-command-line)
-- [Running from Another Python Application](#running-from-another-python-application)
-- [Configuring the Planner](#configuring-the-planner)
-- [JaxPlan Dashboard](#jaxplan-dashboard)
-- [Tuning the Planner](#tuning-the-planner)
-- [Simulation](#simulation)
-- [Citing JaxPlan](#citing-jaxplan)
 
 ## Installation

@@ -65,27 +64,28 @@ pip install pyRDDLGym-jax[extra,dashboard]
 
 ## Running from the Command Line
 
-A basic run script is provided to
+A basic run script is provided to train JaxPlan on any RDDL problem:
 
 ```shell
-
+jaxplan plan <domain> <instance> <method> <episodes>
 ```
 
 where:
 - ``domain`` is the domain identifier as specified in rddlrepository (i.e. Wildfire_MDP_ippc2014), or a path pointing to a valid ``domain.rddl`` file
 - ``instance`` is the instance identifier (i.e. 1, 2, ... 10), or a path pointing to a valid ``instance.rddl`` file
-- ``method`` is the planning method to use (i.e. drp, slp, replan)
+- ``method`` is the planning method to use (i.e. drp, slp, replan) or a path to a valid .cfg file (see section below)
 - ``episodes`` is the (optional) number of episodes to evaluate the learned policy.
 
-The ``method`` parameter supports
+The ``method`` parameter supports four possible modes:
 - ``slp`` is the basic straight line planner described [in this paper](https://proceedings.neurips.cc/paper_files/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)
 - ``drp`` is the deep reactive policy network described [in this paper](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
-- ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step
+- ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step
+- any other argument is interpreted as a file path to a valid configuration file.
 
-For example, the following will train JaxPlan on the Quadcopter domain with 4 drones:
+For example, the following will train JaxPlan on the Quadcopter domain with 4 drones (with default config):
 
 ```shell
-
+jaxplan plan Quadcopter 1 slp
 ```

@@ -167,7 +167,7 @@ controller = JaxOfflineController(planner, **train_args)
 ...
 ```
 
-
+## JaxPlan Dashboard
 
 Since version 1.0, JaxPlan has an optional dashboard that allows keeping track of the planner performance across multiple runs,
 and visualization of the policy or model, and other useful debugging features.

@@ -187,7 +187,7 @@ dashboard=True
 
 More documentation about this and other new features will be coming soon.
 
-
+## Tuning the Planner
 
 It is easy to tune the planner's hyper-parameters efficiently and automatically using Bayesian optimization.
 To do this, first create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:

@@ -250,7 +250,7 @@ tuning.tune(key=42, log_file='path/to/log.csv')
 A basic run script is provided to run the automatic hyper-parameter tuning for the most sensitive parameters of JaxPlan:
 
 ```shell
-
+jaxplan tune <domain> <instance> <method> <trials> <iters> <workers>
 ```
 
 where:
pyrddlgym_jax-1.2/pyRDDLGym_jax/__init__.py (new file)

@@ -0,0 +1 @@
+__version__ = '1.2'
pyrddlgym_jax-1.2/pyRDDLGym_jax/core/assets/favicon.ico (new binary file)
{pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/core/compiler.py

@@ -51,65 +51,65 @@ class JaxRDDLCompiler:
             return func
         return exact_func
 
-    EXACT_RDDL_TO_JAX_NEGATIVE = wrap_logic(ExactLogic.exact_unary_function(jnp.negative))
+    EXACT_RDDL_TO_JAX_NEGATIVE = wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.negative))
     EXACT_RDDL_TO_JAX_ARITHMETIC = {
-        '+': wrap_logic(ExactLogic.exact_binary_function(jnp.add)),
-        '-': wrap_logic(ExactLogic.exact_binary_function(jnp.subtract)),
-        '*': wrap_logic(ExactLogic.exact_binary_function(jnp.multiply)),
-        '/': wrap_logic(ExactLogic.exact_binary_function(jnp.divide))
+        '+': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.add)),
+        '-': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.subtract)),
+        '*': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.multiply)),
+        '/': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.divide))
     }
 
     EXACT_RDDL_TO_JAX_RELATIONAL = {
-        '>=': wrap_logic(ExactLogic.exact_binary_function(jnp.greater_equal)),
-        '<=': wrap_logic(ExactLogic.exact_binary_function(jnp.less_equal)),
-        '<': wrap_logic(ExactLogic.exact_binary_function(jnp.less)),
-        '>': wrap_logic(ExactLogic.exact_binary_function(jnp.greater)),
-        '==': wrap_logic(ExactLogic.exact_binary_function(jnp.equal)),
-        '~=': wrap_logic(ExactLogic.exact_binary_function(jnp.not_equal))
+        '>=': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.greater_equal)),
+        '<=': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.less_equal)),
+        '<': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.less)),
+        '>': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.greater)),
+        '==': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.equal)),
+        '~=': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.not_equal))
     }
 
-    EXACT_RDDL_TO_JAX_LOGICAL_NOT = wrap_logic(ExactLogic.exact_unary_function(jnp.logical_not))
+    EXACT_RDDL_TO_JAX_LOGICAL_NOT = wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.logical_not))
     EXACT_RDDL_TO_JAX_LOGICAL = {
-        '^': wrap_logic(ExactLogic.exact_binary_function(jnp.logical_and)),
-        '&': wrap_logic(ExactLogic.exact_binary_function(jnp.logical_and)),
-        '|': wrap_logic(ExactLogic.exact_binary_function(jnp.logical_or)),
-        '~': wrap_logic(ExactLogic.exact_binary_function(jnp.logical_xor)),
-        '=>': wrap_logic(ExactLogic.exact_binary_implies),
-        '<=>': wrap_logic(ExactLogic.exact_binary_function(jnp.equal))
+        '^': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_and)),
+        '&': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_and)),
+        '|': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_or)),
+        '~': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_xor)),
+        '=>': wrap_logic.__func__(ExactLogic.exact_binary_implies),
+        '<=>': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.equal))
     }
 
     EXACT_RDDL_TO_JAX_AGGREGATION = {
-        'sum': wrap_logic(ExactLogic.exact_aggregation(jnp.sum)),
-        'avg': wrap_logic(ExactLogic.exact_aggregation(jnp.mean)),
-        'prod': wrap_logic(ExactLogic.exact_aggregation(jnp.prod)),
-        'minimum': wrap_logic(ExactLogic.exact_aggregation(jnp.min)),
-        'maximum': wrap_logic(ExactLogic.exact_aggregation(jnp.max)),
-        'forall': wrap_logic(ExactLogic.exact_aggregation(jnp.all)),
-        'exists': wrap_logic(ExactLogic.exact_aggregation(jnp.any)),
-        'argmin': wrap_logic(ExactLogic.exact_aggregation(jnp.argmin)),
-        'argmax': wrap_logic(ExactLogic.exact_aggregation(jnp.argmax))
+        'sum': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.sum)),
+        'avg': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.mean)),
+        'prod': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.prod)),
+        'minimum': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.min)),
+        'maximum': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.max)),
+        'forall': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.all)),
+        'exists': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.any)),
+        'argmin': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.argmin)),
+        'argmax': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.argmax))
     }
 
     EXACT_RDDL_TO_JAX_UNARY = {
-        'abs': wrap_logic(ExactLogic.exact_unary_function(jnp.abs)),
-        'sgn': wrap_logic(ExactLogic.exact_unary_function(jnp.sign)),
-        'round': wrap_logic(ExactLogic.exact_unary_function(jnp.round)),
-        'floor': wrap_logic(ExactLogic.exact_unary_function(jnp.floor)),
-        'ceil': wrap_logic(ExactLogic.exact_unary_function(jnp.ceil)),
-        'cos': wrap_logic(ExactLogic.exact_unary_function(jnp.cos)),
-        'sin': wrap_logic(ExactLogic.exact_unary_function(jnp.sin)),
-        'tan': wrap_logic(ExactLogic.exact_unary_function(jnp.tan)),
-        'acos': wrap_logic(ExactLogic.exact_unary_function(jnp.arccos)),
-        'asin': wrap_logic(ExactLogic.exact_unary_function(jnp.arcsin)),
-        'atan': wrap_logic(ExactLogic.exact_unary_function(jnp.arctan)),
-        'cosh': wrap_logic(ExactLogic.exact_unary_function(jnp.cosh)),
-        'sinh': wrap_logic(ExactLogic.exact_unary_function(jnp.sinh)),
-        'tanh': wrap_logic(ExactLogic.exact_unary_function(jnp.tanh)),
-        'exp': wrap_logic(ExactLogic.exact_unary_function(jnp.exp)),
-        'ln': wrap_logic(ExactLogic.exact_unary_function(jnp.log)),
-        'sqrt': wrap_logic(ExactLogic.exact_unary_function(jnp.sqrt)),
-        'lngamma': wrap_logic(ExactLogic.exact_unary_function(scipy.special.gammaln)),
-        'gamma': wrap_logic(ExactLogic.exact_unary_function(scipy.special.gamma))
+        'abs': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.abs)),
+        'sgn': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sign)),
+        'round': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.round)),
+        'floor': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.floor)),
+        'ceil': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.ceil)),
+        'cos': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.cos)),
+        'sin': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sin)),
+        'tan': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.tan)),
+        'acos': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.arccos)),
+        'asin': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.arcsin)),
+        'atan': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.arctan)),
+        'cosh': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.cosh)),
+        'sinh': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sinh)),
+        'tanh': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.tanh)),
+        'exp': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.exp)),
+        'ln': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.log)),
+        'sqrt': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sqrt)),
+        'lngamma': wrap_logic.__func__(ExactLogic.exact_unary_function(scipy.special.gammaln)),
+        'gamma': wrap_logic.__func__(ExactLogic.exact_unary_function(scipy.special.gamma))
     }
 
     @staticmethod

@@ -117,23 +117,23 @@ class JaxRDDLCompiler:
         return jnp.log(x) / jnp.log(y), params
 
     EXACT_RDDL_TO_JAX_BINARY = {
-        'div': wrap_logic(ExactLogic.exact_binary_function(jnp.floor_divide)),
-        'mod': wrap_logic(ExactLogic.exact_binary_function(jnp.mod)),
-        'fmod': wrap_logic(ExactLogic.exact_binary_function(jnp.mod)),
-        'min': wrap_logic(ExactLogic.exact_binary_function(jnp.minimum)),
-        'max': wrap_logic(ExactLogic.exact_binary_function(jnp.maximum)),
-        'pow': wrap_logic(ExactLogic.exact_binary_function(jnp.power)),
-        'log': wrap_logic(_jax_wrapped_calc_log_exact),
-        'hypot': wrap_logic(ExactLogic.exact_binary_function(jnp.hypot)),
+        'div': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.floor_divide)),
+        'mod': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.mod)),
+        'fmod': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.mod)),
+        'min': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.minimum)),
+        'max': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.maximum)),
+        'pow': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.power)),
+        'log': wrap_logic.__func__(_jax_wrapped_calc_log_exact.__func__),
+        'hypot': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.hypot)),
     }
 
-    EXACT_RDDL_TO_JAX_IF = wrap_logic(ExactLogic.exact_if_then_else)
-    EXACT_RDDL_TO_JAX_SWITCH = wrap_logic(ExactLogic.exact_switch)
+    EXACT_RDDL_TO_JAX_IF = wrap_logic.__func__(ExactLogic.exact_if_then_else)
+    EXACT_RDDL_TO_JAX_SWITCH = wrap_logic.__func__(ExactLogic.exact_switch)
 
-    EXACT_RDDL_TO_JAX_BERNOULLI = wrap_logic(ExactLogic.exact_bernoulli)
-    EXACT_RDDL_TO_JAX_DISCRETE = wrap_logic(ExactLogic.exact_discrete)
-    EXACT_RDDL_TO_JAX_POISSON = wrap_logic(ExactLogic.exact_poisson)
-    EXACT_RDDL_TO_JAX_GEOMETRIC = wrap_logic(ExactLogic.exact_geometric)
+    EXACT_RDDL_TO_JAX_BERNOULLI = wrap_logic.__func__(ExactLogic.exact_bernoulli)
+    EXACT_RDDL_TO_JAX_DISCRETE = wrap_logic.__func__(ExactLogic.exact_discrete)
+    EXACT_RDDL_TO_JAX_POISSON = wrap_logic.__func__(ExactLogic.exact_poisson)
+    EXACT_RDDL_TO_JAX_GEOMETRIC = wrap_logic.__func__(ExactLogic.exact_geometric)
 
     def __init__(self, rddl: RDDLLiftedModel,
                  allow_synchronous_state: bool=True,
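The mechanical `wrap_logic(...)` → `wrap_logic.__func__(...)` rewrite above is a Python-version compatibility fix rather than a behavioral change. Below is a minimal standalone sketch of the underlying issue (the `OpTable`/`wrap_logic` names here are illustrative, not the pyRDDLGym implementation):

```python
# Why `.__func__` is needed: while a class body is still executing, a name
# decorated with @staticmethod refers to the staticmethod descriptor, not to
# a callable function. Calling the descriptor directly raises
# "TypeError: 'staticmethod' object is not callable" on Python <= 3.9.
class OpTable:
    @staticmethod
    def wrap_logic(f):
        # stand-in for the compiler's wrapper: tag f and return it unchanged
        f.is_exact = True
        return f

    # `OpTable` does not exist yet here, so we cannot write OpTable.wrap_logic;
    # `wrap_logic.__func__` retrieves the plain underlying function, which
    # works on every version from the package's declared 3.9 floor upward.
    UNARY = {
        'neg': wrap_logic.__func__(lambda x: -x),
        'sqr': wrap_logic.__func__(lambda x: x * x),
    }

print(OpTable.UNARY['neg'](3))        # -3
print(OpTable.UNARY['sqr'].is_exact)  # True
```

Python 3.10 made `staticmethod` objects directly callable, which is presumably why the 1.0 spelling worked in some environments; note that the `'log'` entry additionally unwraps its argument (`_jax_wrapped_calc_log_exact.__func__`) because that helper is itself a staticmethod.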
{pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/core/planner.py

@@ -65,9 +65,8 @@ def _parse_config_file(path: str):
    config = configparser.RawConfigParser()
    config.optionxform = str
    config.read(path)
-    args = {k: literal_eval(v)
-            for section in config.sections()
-            for (k, v) in config.items(section)}
+    args = {section: {k: literal_eval(v) for (k, v) in config.items(section)}
+            for section in config.sections()}
    return config, args
 
 
@@ -75,9 +74,8 @@ def _parse_config_string(value: str):
    config = configparser.RawConfigParser()
    config.optionxform = str
    config.read_string(value)
-    args = {k: literal_eval(v)
-            for section in config.sections()
-            for (k, v) in config.items(section)}
+    args = {section: {k: literal_eval(v) for (k, v) in config.items(section)}
+            for section in config.sections()}
    return config, args
 
 
@@ -90,9 +88,9 @@ def _getattr_any(packages, item):
 
 
 def _load_config(config, args):
-    model_args = {k: args[k] for (k, _) in config.items('Model')}
-    planner_args = {k: args[k] for (k, _) in config.items('Optimizer')}
-    train_args = {k: args[k] for (k, _) in config.items('Training')}
+    model_args = {k: args['Model'][k] for (k, _) in config.items('Model')}
+    planner_args = {k: args['Optimizer'][k] for (k, _) in config.items('Optimizer')}
+    train_args = {k: args['Training'][k] for (k, _) in config.items('Training')}
 
    # read the model settings
    logic_name = model_args.get('logic', 'FuzzyLogic')

@@ -1661,7 +1659,7 @@ r"""
    def optimize_generator(self, key: Optional[random.PRNGKey]=None,
                           epochs: int=999999,
                           train_seconds: float=120.,
-                           dashboard: Optional[
+                           dashboard: Optional[Any]=None,
                           dashboard_id: Optional[str]=None,
                           model_params: Optional[Dict[str, Any]]=None,
                           policy_hyperparams: Optional[Dict[str, Any]]=None,
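The `_parse_config_file`/`_parse_config_string` hunks change the parsed arguments from one flat dict into a dict of dicts keyed by section, and `_load_config` now indexes the section first. A minimal sketch of the behavioral difference, using hypothetical config text (not a config shipped with the package):

```python
import configparser
from ast import literal_eval

TEXT = """
[Model]
logic='FuzzyLogic'
weight=10

[Optimizer]
weight=0.1
"""

config = configparser.RawConfigParser()
config.optionxform = str       # preserve option-name case, as in the diff
config.read_string(TEXT)

# 1.0 behavior: one flat dict; 'weight' from [Optimizer] silently clobbers
# the value that [Model] defined
flat = {k: literal_eval(v)
        for section in config.sections()
        for (k, v) in config.items(section)}
assert flat['weight'] == 0.1

# 1.2 behavior: one dict per section, so duplicate option names cannot collide
nested = {section: {k: literal_eval(v) for (k, v) in config.items(section)}
          for section in config.sections()}
assert nested['Model']['weight'] == 10
assert nested['Optimizer']['weight'] == 0.1
```

This is why `_load_config` switches from `args[k]` to `args['Model'][k]`, `args['Optimizer'][k]`, and `args['Training'][k]`: the per-section lookups stay unambiguous even when sections reuse an option name.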
{pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/core/tuning.py

@@ -4,6 +4,7 @@ import threading
 import multiprocessing
 import os
 import time
+import traceback
 from typing import Any, Callable, Dict, Iterable, Optional, Tuple
 import warnings
 warnings.filterwarnings("ignore")

@@ -14,6 +15,7 @@ from bayes_opt.acquisition import AcquisitionFunction, UpperConfidenceBound
 import jax
 import numpy as np
 
+from pyRDDLGym.core.debug.exception import raise_warning
 from pyRDDLGym.core.env import RDDLEnv
 
 from pyRDDLGym_jax.core.planner import (

@@ -64,6 +66,7 @@ class JaxParameterTuning:
                 hyperparams: Hyperparameters,
                 online: bool,
                 eval_trials: int=5,
+                 rollouts_per_trial: int=1,
                 verbose: bool=True,
                 timeout_tuning: float=np.inf,
                 pool_context: str='spawn',

@@ -87,6 +90,8 @@ class JaxParameterTuning:
        hyperparameters in general (in seconds)
        :param eval_trials: how many trials to perform independent training
        in order to estimate the return for each set of hyper-parameters
+        :param rollouts_per_trial: how many rollouts to perform during evaluation
+        at the end of each training trial (only applies when online=False)
        :param verbose: whether to print intermediate results of tuning
        :param pool_context: context for multiprocessing pool (default "spawn")
        :param num_workers: how many points to evaluate in parallel

@@ -108,6 +113,7 @@ class JaxParameterTuning:
        self.hyperparams_dict = hyperparams_dict
        self.online = online
        self.eval_trials = eval_trials
+        self.rollouts_per_trial = rollouts_per_trial
        self.verbose = verbose
 
        # Bayesian parameters

@@ -154,6 +160,7 @@ class JaxParameterTuning:
            f' mp_pool_poll_frequency ={self.poll_frequency}\n'
            f'meta-objective parameters:\n'
            f' planning_trials_per_iter ={self.eval_trials}\n'
+            f' rollouts_per_trial       ={self.rollouts_per_trial}\n'
            f' acquisition_fn ={self.acquisition}')
 
    @staticmethod

@@ -200,12 +207,14 @@ class JaxParameterTuning:
 
    @staticmethod
    def offline_trials(env, planner, train_args, key, iteration, index, num_trials,
-                       verbose, viz, queue):
+                       rollouts_per_trial, verbose, viz, queue):
        average_reward = 0.0
        for trial in range(num_trials):
            key, subkey = jax.random.split(key)
+
+            # for the dashboard
            experiment_id = f'iter={iteration}, worker={index}, trial={trial}'
-            if queue is not None:
+            if queue is not None and JaxPlannerDashboard is not None:
                queue.put((
                    experiment_id,
                    JaxPlannerDashboard.get_planner_info(planner),

@@ -224,7 +233,8 @@ class JaxParameterTuning:
            policy = JaxOfflineController(
                planner=planner, key=subkey, tqdm_position=index,
                params=best_params, train_on_reset=False)
-            total_reward = policy.evaluate(env,
+            total_reward = policy.evaluate(env, episodes=rollouts_per_trial,
+                                           seed=np.array(subkey)[0])['mean']
 
            # update average reward
            if verbose:

@@ -243,8 +253,10 @@ class JaxParameterTuning:
        average_reward = 0.0
        for trial in range(num_trials):
            key, subkey = jax.random.split(key)
+
+            # for the dashboard
            experiment_id = f'iter={iteration}, worker={index}, trial={trial}'
-            if queue is not None:
+            if queue is not None and JaxPlannerDashboard is not None:
                queue.put((
                    experiment_id,
                    JaxPlannerDashboard.get_planner_info(planner),

@@ -304,6 +316,7 @@ class JaxParameterTuning:
        domain = kwargs['domain']
        instance = kwargs['instance']
        num_trials = kwargs['eval_trials']
+        rollouts_per_trial = kwargs['rollouts_per_trial']
        viz = kwargs['viz']
        verbose = kwargs['verbose']
 
@@ -332,7 +345,7 @@ class JaxParameterTuning:
        else:
            average_reward = JaxParameterTuning.offline_trials(
                env, planner, train_args, key, iteration, index,
-                num_trials, verbose, viz, queue
+                num_trials, rollouts_per_trial, verbose, viz, queue
            )
 
        pid = os.getpid()

@@ -353,7 +366,7 @@ class JaxParameterTuning:
            writer.writerow(COLUMNS + list(self.hyperparams_dict.keys()))
 
        # create a dash-board for visualizing experiment runs
-        if show_dashboard:
+        if show_dashboard and JaxPlannerDashboard is not None:
            dashboard = JaxPlannerDashboard()
            dashboard.launch()

@@ -365,6 +378,7 @@ class JaxParameterTuning:
            'domain': self.env.domain_text,
            'instance': self.env.instance_text,
            'eval_trials': self.eval_trials,
+            'rollouts_per_trial': self.rollouts_per_trial,
            'viz': self.env._visualizer,
            'verbose': self.verbose
        }
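The thread running through these hunks is the new `rollouts_per_trial` parameter: in offline mode each of the `eval_trials` training runs is now scored by `policy.evaluate(env, episodes=rollouts_per_trial, ...)['mean']` rather than by a single rollout. A standalone sketch of the effect on the meta-objective that Bayesian optimization sees (the return distribution and the running-mean update below are illustrative assumptions, not package code):

```python
import numpy as np

rng = np.random.default_rng(42)
eval_trials, rollouts_per_trial = 5, 10

average_reward = 0.0
for trial in range(eval_trials):
    # stand-in for training a policy and then evaluating it:
    # policy.evaluate(env, episodes=rollouts_per_trial, ...)['mean']
    episode_returns = rng.normal(loc=-40.0, scale=5.0, size=rollouts_per_trial)
    total_reward = episode_returns.mean()
    # accumulate a running mean over trials (illustrative update rule)
    average_reward += (total_reward - average_reward) / (trial + 1)

print(f'meta-objective estimate over {eval_trials} trials: {average_reward:.2f}')
```

Averaging several rollouts per trained policy lowers the variance of each observation handed to the surrogate model, at the cost of extra simulation time per candidate point.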
{pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/core/visualization.py

@@ -1405,7 +1405,7 @@ class JaxPlannerDashboard:
        self.test_reward_dist[experiment_id] = callback['reward']
        self.train_state_fluents[experiment_id] = {
            name: np.asarray(callback['train_log']['fluents'][name])
-            for name in rddl.state_fluents
+            for name in rddl.state_fluents
        }
        self.test_state_fluents[experiment_id] = {
            name: np.asarray(callback['fluents'][name])
pyrddlgym_jax-1.2/pyRDDLGym_jax/entry_point.py (new file)

@@ -0,0 +1,27 @@
+import argparse
+
+from pyRDDLGym_jax.examples import run_plan, run_tune
+
+def main():
+    parser = argparse.ArgumentParser(description="Command line parser for the JaxPlan planner.")
+    subparsers = parser.add_subparsers(dest="jaxplan", required=True)
+
+    # planning
+    parser_plan = subparsers.add_parser("plan", help="Executes JaxPlan on a specified RDDL problem and method (slp, drp, or replan).")
+    parser_plan.add_argument('args', nargs=argparse.REMAINDER)
+
+    # tuning
+    parser_tune = subparsers.add_parser("tune", help="Tunes JaxPlan on a specified RDDL problem and method (slp, drp, or replan).")
+    parser_tune.add_argument('args', nargs=argparse.REMAINDER)
+
+    # dispatch
+    args = parser.parse_args()
+    if args.jaxplan == "plan":
+        run_plan.run_from_args(args.args)
+    elif args.jaxplan == "tune":
+        run_tune.run_from_args(args.args)
+    else:
+        parser.print_help()
+
+if __name__ == "__main__":
+    main()

pyrddlgym_jax-1.2/pyRDDLGym_jax/examples/configs/__init__.py (new empty file)
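This new module is what the `jaxplan` console script (declared in `setup.py` below) invokes: each subcommand simply forwards its remaining arguments to `run_plan.run_from_args` or `run_tune.run_from_args`, which is why both example scripts gain `run_from_args` wrappers in the diffs that follow.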
{pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/run_plan.py

@@ -12,7 +12,7 @@ The syntax for running this example is:
 where:
    <domain> is the name of a domain located in the /Examples directory
    <instance> is the instance number
-    <method> is
+    <method> is slp, drp, replan, or a path to a valid .cfg file
    <episodes> is the optional number of evaluation rollouts
 '''
 import os

@@ -32,12 +32,19 @@ def main(domain, instance, method, episodes=1):
    env = pyRDDLGym.make(domain, instance, vectorized=True)
 
    # load the config file with planner settings
-
-
-
-
-
-
+    if method in ['drp', 'slp', 'replan']:
+        abs_path = os.path.dirname(os.path.abspath(__file__))
+        config_path = os.path.join(abs_path, 'configs', f'{domain}_{method}.cfg')
+        if not os.path.isfile(config_path):
+            raise_warning(f'Config file {config_path} was not found, '
+                          f'using default_{method}.cfg.', 'red')
+            config_path = os.path.join(abs_path, 'configs', f'default_{method}.cfg')
+    elif os.path.isfile(method):
+        config_path = method
+    else:
+        print('method must be slp, drp, replan, or a path to a valid .cfg file.')
+        exit(1)
+
    planner_args, _, train_args = load_config(config_path)
    if 'dashboard' in train_args:
        train_args['dashboard'].launch()

@@ -54,16 +61,16 @@ def main(domain, instance, method, episodes=1):
    controller.evaluate(env, episodes=episodes, verbose=True, render=True)
    env.close()
 
-
-
-    args = sys.argv[1:]
+
+def run_from_args(args):
    if len(args) < 3:
        print('python run_plan.py <domain> <instance> <method> [<episodes>]')
        exit(1)
-    if args[2] not in ['drp', 'slp', 'replan']:
-        print('<method> in [drp, slp, replan]')
-        exit(1)
    kwargs = {'domain': args[0], 'instance': args[1], 'method': args[2]}
    if len(args) >= 4: kwargs['episodes'] = int(args[3])
    main(**kwargs)
+
+
+if __name__ == "__main__":
+    run_from_args(sys.argv[1:])
{pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax/examples/run_tune.py

@@ -75,8 +75,7 @@ def main(domain, instance, method, trials=5, iters=20, workers=4):
    env.close()
 
 
-
-    args = sys.argv[1:]
+def run_from_args(args):
    if len(args) < 3:
        print('python run_tune.py <domain> <instance> <method> [<trials>] [<iters>] [<workers>]')
        exit(1)

@@ -88,4 +87,7 @@ if __name__ == "__main__":
    if len(args) >= 5: kwargs['iters'] = int(args[4])
    if len(args) >= 6: kwargs['workers'] = int(args[5])
    main(**kwargs)
-
+
+
+if __name__ == "__main__":
+    run_from_args(sys.argv[1:])
{pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax.egg-info/PKG-INFO

Same changes as the top-level PKG-INFO diff above (the egg-info copy tracks it verbatim).
{pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/pyRDDLGym_jax.egg-info/SOURCES.txt

@@ -2,9 +2,11 @@ LICENSE
 README.md
 setup.py
 pyRDDLGym_jax/__init__.py
+pyRDDLGym_jax/entry_point.py
 pyRDDLGym_jax.egg-info/PKG-INFO
 pyRDDLGym_jax.egg-info/SOURCES.txt
 pyRDDLGym_jax.egg-info/dependency_links.txt
+pyRDDLGym_jax.egg-info/entry_points.txt
 pyRDDLGym_jax.egg-info/requires.txt
 pyRDDLGym_jax.egg-info/top_level.txt
 pyRDDLGym_jax/core/__init__.py

@@ -14,6 +16,8 @@ pyRDDLGym_jax/core/planner.py
 pyRDDLGym_jax/core/simulator.py
 pyRDDLGym_jax/core/tuning.py
 pyRDDLGym_jax/core/visualization.py
+pyRDDLGym_jax/core/assets/__init__.py
+pyRDDLGym_jax/core/assets/favicon.ico
 pyRDDLGym_jax/examples/__init__.py
 pyRDDLGym_jax/examples/run_gradient.py
 pyRDDLGym_jax/examples/run_gym.py
{pyrddlgym_jax-1.0 → pyrddlgym_jax-1.2}/setup.py

@@ -19,7 +19,7 @@ long_description = (Path(__file__).parent / "README.md").read_text()
 
 setup(
    name='pyRDDLGym-jax',
-    version='1.0',
+    version='1.2',
    author="Michael Gimelfarb, Ayal Taitler, Scott Sanner",
    author_email="mike.gimelfarb@mail.utoronto.ca, ataitler@gmail.com, ssanner@mie.utoronto.ca",
    description="pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.",

@@ -41,15 +41,22 @@ setup(
        'dashboard': ['dash>=2.18.0', 'dash-bootstrap-components>=1.6.0']
    },
    python_requires=">=3.9",
-    package_data={'': ['*.cfg']},
+    package_data={'': ['*.cfg', '*.ico']},
    include_package_data=True,
+    entry_points={
+        'console_scripts': ['jaxplan=pyRDDLGym_jax.entry_point:main'],
+    },
    classifiers=[
-        "Development Status ::
+        "Development Status :: 5 - Production/Stable",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: MIT License",
        "Natural Language :: English",
        "Operating System :: OS Independent",
        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 3.9",
+        "Programming Language :: Python :: 3.10",
+        "Programming Language :: Python :: 3.11",
+        "Programming Language :: Python :: 3.12",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
 )
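The `entry_points` declaration is also the origin of the new two-line `pyRDDLGym_jax.egg-info/entry_points.txt` in the file list at the top: setuptools generates that file from the `console_scripts` mapping at build time.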
pyrddlgym_jax-1.0/pyRDDLGym_jax/__init__.py (removed)

@@ -1 +0,0 @@
-__version__ = '1.0'
All remaining files (+0 -0 in the list above) were renamed or moved between the 1.0 and 1.2 source trees without content changes.