pyRDDLGym-jax 1.1__py3-none-any.whl → 1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/assets/__init__.py +0 -0
- pyRDDLGym_jax/core/assets/favicon.ico +0 -0
- pyRDDLGym_jax/entry_point.py +27 -0
- pyRDDLGym_jax/examples/run_plan.py +20 -13
- pyRDDLGym_jax/examples/run_tune.py +5 -3
- {pyRDDLGym_jax-1.1.dist-info → pyRDDLGym_jax-1.2.dist-info}/METADATA +29 -13
- {pyRDDLGym_jax-1.1.dist-info → pyRDDLGym_jax-1.2.dist-info}/RECORD +12 -8
- {pyRDDLGym_jax-1.1.dist-info → pyRDDLGym_jax-1.2.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax-1.2.dist-info/entry_points.txt +2 -0
- {pyRDDLGym_jax-1.1.dist-info → pyRDDLGym_jax-1.2.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-1.1.dist-info → pyRDDLGym_jax-1.2.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/__init__.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = '1.
|
|
1
|
+
__version__ = '1.2'
|
|
File without changes
|
|
Binary file
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
|
|
3
|
+
from pyRDDLGym_jax.examples import run_plan, run_tune
|
|
4
|
+
|
|
5
|
+
def main():
|
|
6
|
+
parser = argparse.ArgumentParser(description="Command line parser for the JaxPlan planner.")
|
|
7
|
+
subparsers = parser.add_subparsers(dest="jaxplan", required=True)
|
|
8
|
+
|
|
9
|
+
# planning
|
|
10
|
+
parser_plan = subparsers.add_parser("plan", help="Executes JaxPlan on a specified RDDL problem and method (slp, drp, or replan).")
|
|
11
|
+
parser_plan.add_argument('args', nargs=argparse.REMAINDER)
|
|
12
|
+
|
|
13
|
+
# tuning
|
|
14
|
+
parser_tune = subparsers.add_parser("tune", help="Tunes JaxPlan on a specified RDDL problem and method (slp, drp, or replan).")
|
|
15
|
+
parser_tune.add_argument('args', nargs=argparse.REMAINDER)
|
|
16
|
+
|
|
17
|
+
# dispatch
|
|
18
|
+
args = parser.parse_args()
|
|
19
|
+
if args.jaxplan == "plan":
|
|
20
|
+
run_plan.run_from_args(args.args)
|
|
21
|
+
elif args.jaxplan == "tune":
|
|
22
|
+
run_tune.run_from_args(args.args)
|
|
23
|
+
else:
|
|
24
|
+
parser.print_help()
|
|
25
|
+
|
|
26
|
+
if __name__ == "__main__":
|
|
27
|
+
main()
|
|
@@ -12,7 +12,7 @@ The syntax for running this example is:
|
|
|
12
12
|
where:
|
|
13
13
|
<domain> is the name of a domain located in the /Examples directory
|
|
14
14
|
<instance> is the instance number
|
|
15
|
-
<method> is
|
|
15
|
+
<method> is slp, drp, replan, or a path to a valid .cfg file
|
|
16
16
|
<episodes> is the optional number of evaluation rollouts
|
|
17
17
|
'''
|
|
18
18
|
import os
|
|
@@ -32,12 +32,19 @@ def main(domain, instance, method, episodes=1):
|
|
|
32
32
|
env = pyRDDLGym.make(domain, instance, vectorized=True)
|
|
33
33
|
|
|
34
34
|
# load the config file with planner settings
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
35
|
+
if method in ['drp', 'slp', 'replan']:
|
|
36
|
+
abs_path = os.path.dirname(os.path.abspath(__file__))
|
|
37
|
+
config_path = os.path.join(abs_path, 'configs', f'{domain}_{method}.cfg')
|
|
38
|
+
if not os.path.isfile(config_path):
|
|
39
|
+
raise_warning(f'Config file {config_path} was not found, '
|
|
40
|
+
f'using default_{method}.cfg.', 'red')
|
|
41
|
+
config_path = os.path.join(abs_path, 'configs', f'default_{method}.cfg')
|
|
42
|
+
elif os.path.isfile(method):
|
|
43
|
+
config_path = method
|
|
44
|
+
else:
|
|
45
|
+
print('method must be slp, drp, replan, or a path to a valid .cfg file.')
|
|
46
|
+
exit(1)
|
|
47
|
+
|
|
41
48
|
planner_args, _, train_args = load_config(config_path)
|
|
42
49
|
if 'dashboard' in train_args:
|
|
43
50
|
train_args['dashboard'].launch()
|
|
@@ -54,16 +61,16 @@ def main(domain, instance, method, episodes=1):
|
|
|
54
61
|
controller.evaluate(env, episodes=episodes, verbose=True, render=True)
|
|
55
62
|
env.close()
|
|
56
63
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
args = sys.argv[1:]
|
|
64
|
+
|
|
65
|
+
def run_from_args(args):
|
|
60
66
|
if len(args) < 3:
|
|
61
67
|
print('python run_plan.py <domain> <instance> <method> [<episodes>]')
|
|
62
68
|
exit(1)
|
|
63
|
-
if args[2] not in ['drp', 'slp', 'replan']:
|
|
64
|
-
print('<method> in [drp, slp, replan]')
|
|
65
|
-
exit(1)
|
|
66
69
|
kwargs = {'domain': args[0], 'instance': args[1], 'method': args[2]}
|
|
67
70
|
if len(args) >= 4: kwargs['episodes'] = int(args[3])
|
|
68
71
|
main(**kwargs)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
if __name__ == "__main__":
|
|
75
|
+
run_from_args(sys.argv[1:])
|
|
69
76
|
|
|
@@ -75,8 +75,7 @@ def main(domain, instance, method, trials=5, iters=20, workers=4):
|
|
|
75
75
|
env.close()
|
|
76
76
|
|
|
77
77
|
|
|
78
|
-
|
|
79
|
-
args = sys.argv[1:]
|
|
78
|
+
def run_from_args(args):
|
|
80
79
|
if len(args) < 3:
|
|
81
80
|
print('python run_tune.py <domain> <instance> <method> [<trials>] [<iters>] [<workers>]')
|
|
82
81
|
exit(1)
|
|
@@ -88,4 +87,7 @@ if __name__ == "__main__":
|
|
|
88
87
|
if len(args) >= 5: kwargs['iters'] = int(args[4])
|
|
89
88
|
if len(args) >= 6: kwargs['workers'] = int(args[5])
|
|
90
89
|
main(**kwargs)
|
|
91
|
-
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
if __name__ == "__main__":
|
|
93
|
+
run_from_args(sys.argv[1:])
|
|
@@ -1,17 +1,21 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
2
|
Name: pyRDDLGym-jax
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.2
|
|
4
4
|
Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
|
|
5
5
|
Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
|
|
6
6
|
Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
|
|
7
7
|
Author-email: mike.gimelfarb@mail.utoronto.ca, ataitler@gmail.com, ssanner@mie.utoronto.ca
|
|
8
8
|
License: MIT License
|
|
9
|
-
Classifier: Development Status ::
|
|
9
|
+
Classifier: Development Status :: 5 - Production/Stable
|
|
10
10
|
Classifier: Intended Audience :: Science/Research
|
|
11
11
|
Classifier: License :: OSI Approved :: MIT License
|
|
12
12
|
Classifier: Natural Language :: English
|
|
13
13
|
Classifier: Operating System :: OS Independent
|
|
14
14
|
Classifier: Programming Language :: Python :: 3
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
15
19
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
16
20
|
Requires-Python: >=3.9
|
|
17
21
|
Description-Content-Type: text/markdown
|
|
@@ -28,6 +32,17 @@ Requires-Dist: rddlrepository>=2.0; extra == "extra"
|
|
|
28
32
|
Provides-Extra: dashboard
|
|
29
33
|
Requires-Dist: dash>=2.18.0; extra == "dashboard"
|
|
30
34
|
Requires-Dist: dash-bootstrap-components>=1.6.0; extra == "dashboard"
|
|
35
|
+
Dynamic: author
|
|
36
|
+
Dynamic: author-email
|
|
37
|
+
Dynamic: classifier
|
|
38
|
+
Dynamic: description
|
|
39
|
+
Dynamic: description-content-type
|
|
40
|
+
Dynamic: home-page
|
|
41
|
+
Dynamic: license
|
|
42
|
+
Dynamic: provides-extra
|
|
43
|
+
Dynamic: requires-dist
|
|
44
|
+
Dynamic: requires-python
|
|
45
|
+
Dynamic: summary
|
|
31
46
|
|
|
32
47
|
# pyRDDLGym-jax
|
|
33
48
|
|
|
@@ -95,27 +110,28 @@ pip install pyRDDLGym-jax[extra,dashboard]
|
|
|
95
110
|
|
|
96
111
|
## Running from the Command Line
|
|
97
112
|
|
|
98
|
-
A basic run script is provided to
|
|
113
|
+
A basic run script is provided to train JaxPlan on any RDDL problem:
|
|
99
114
|
|
|
100
115
|
```shell
|
|
101
|
-
|
|
116
|
+
jaxplan plan <domain> <instance> <method> <episodes>
|
|
102
117
|
```
|
|
103
118
|
|
|
104
119
|
where:
|
|
105
120
|
- ``domain`` is the domain identifier as specified in rddlrepository (i.e. Wildfire_MDP_ippc2014), or a path pointing to a valid ``domain.rddl`` file
|
|
106
121
|
- ``instance`` is the instance identifier (i.e. 1, 2, ... 10), or a path pointing to a valid ``instance.rddl`` file
|
|
107
|
-
- ``method`` is the planning method to use (i.e. drp, slp, replan)
|
|
122
|
+
- ``method`` is the planning method to use (i.e. drp, slp, replan) or a path to a valid .cfg file (see section below)
|
|
108
123
|
- ``episodes`` is the (optional) number of episodes to evaluate the learned policy.
|
|
109
124
|
|
|
110
|
-
The ``method`` parameter supports
|
|
125
|
+
The ``method`` parameter supports four possible modes:
|
|
111
126
|
- ``slp`` is the basic straight line planner described [in this paper](https://proceedings.neurips.cc/paper_files/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)
|
|
112
127
|
- ``drp`` is the deep reactive policy network described [in this paper](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
|
|
113
|
-
- ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step
|
|
128
|
+
- ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step
|
|
129
|
+
- any other argument is interpreted as a file path to a valid configuration file.
|
|
114
130
|
|
|
115
|
-
For example, the following will train JaxPlan on the Quadcopter domain with 4 drones:
|
|
131
|
+
For example, the following will train JaxPlan on the Quadcopter domain with 4 drones (with default config):
|
|
116
132
|
|
|
117
133
|
```shell
|
|
118
|
-
|
|
134
|
+
jaxplan plan Quadcopter 1 slp
|
|
119
135
|
```
|
|
120
136
|
|
|
121
137
|
## Running from Another Python Application
|
|
@@ -197,7 +213,7 @@ controller = JaxOfflineController(planner, **train_args)
|
|
|
197
213
|
...
|
|
198
214
|
```
|
|
199
215
|
|
|
200
|
-
|
|
216
|
+
## JaxPlan Dashboard
|
|
201
217
|
|
|
202
218
|
Since version 1.0, JaxPlan has an optional dashboard that allows keeping track of the planner performance across multiple runs,
|
|
203
219
|
and visualization of the policy or model, and other useful debugging features.
|
|
@@ -217,7 +233,7 @@ dashboard=True
|
|
|
217
233
|
|
|
218
234
|
More documentation about this and other new features will be coming soon.
|
|
219
235
|
|
|
220
|
-
|
|
236
|
+
## Tuning the Planner
|
|
221
237
|
|
|
222
238
|
It is easy to tune the planner's hyper-parameters efficiently and automatically using Bayesian optimization.
|
|
223
239
|
To do this, first create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
|
|
@@ -280,7 +296,7 @@ tuning.tune(key=42, log_file='path/to/log.csv')
|
|
|
280
296
|
A basic run script is provided to run the automatic hyper-parameter tuning for the most sensitive parameters of JaxPlan:
|
|
281
297
|
|
|
282
298
|
```shell
|
|
283
|
-
|
|
299
|
+
jaxplan tune <domain> <instance> <method> <trials> <iters> <workers>
|
|
284
300
|
```
|
|
285
301
|
|
|
286
302
|
where:
|
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
pyRDDLGym_jax/__init__.py,sha256=
|
|
1
|
+
pyRDDLGym_jax/__init__.py,sha256=LTT-ZpL6vrKdC5t0O71pJnk3zMhDf1eXkNmoLoIRupo,19
|
|
2
|
+
pyRDDLGym_jax/entry_point.py,sha256=dxDlO_5gneEEViwkLCg30Z-KVzUgdRXaKuFjoZklkA0,974
|
|
2
3
|
pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
4
|
pyRDDLGym_jax/core/compiler.py,sha256=qy1TSivdpuZxWecDl5HEM0PXX45JB7DHzV7uAB8kmbE,88696
|
|
4
5
|
pyRDDLGym_jax/core/logic.py,sha256=iYvLgWyQd_mrkwwoeRWao9NzjmhsObQnPq4DphILw1Q,38425
|
|
@@ -6,12 +7,14 @@ pyRDDLGym_jax/core/planner.py,sha256=oKs9js7xyIc9-bxQFZSQNBw9s1nWQlz4DjENwEgSojY
|
|
|
6
7
|
pyRDDLGym_jax/core/simulator.py,sha256=JpmwfPqYPBfEhmQ04ufBeclZOQ-U1ZiyAtLf1AIwO2M,8462
|
|
7
8
|
pyRDDLGym_jax/core/tuning.py,sha256=LBhoVQZWWhYQj89gpM2B4xVHlYlKDt4psw4Be9cBbSY,23685
|
|
8
9
|
pyRDDLGym_jax/core/visualization.py,sha256=uKhC8z0TeX9BklPNoxSVt0g5pkqhgxrQClQAih78ybY,68292
|
|
10
|
+
pyRDDLGym_jax/core/assets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
|
+
pyRDDLGym_jax/core/assets/favicon.ico,sha256=RMMrI9YvmF81TgYG7FO7UAre6WmYFkV3B2GmbA1l0kM,175085
|
|
9
12
|
pyRDDLGym_jax/examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
10
13
|
pyRDDLGym_jax/examples/run_gradient.py,sha256=KhXvijRDZ4V7N8NOI2WV8ePGpPna5_vnET61YwS7Tco,2919
|
|
11
14
|
pyRDDLGym_jax/examples/run_gym.py,sha256=rXvNWkxe4jHllvbvU_EOMji_2-2k5d4tbBKhpMm_Gaw,1526
|
|
12
|
-
pyRDDLGym_jax/examples/run_plan.py,sha256=
|
|
15
|
+
pyRDDLGym_jax/examples/run_plan.py,sha256=v2AvwgIa4Ejr626vBOgWFJIQvay3IPKWno02ztIFCYc,2768
|
|
13
16
|
pyRDDLGym_jax/examples/run_scipy.py,sha256=wvcpWCvdjvYHntO95a7JYfY2fuCMUTKnqjJikW0PnL4,2291
|
|
14
|
-
pyRDDLGym_jax/examples/run_tune.py,sha256=
|
|
17
|
+
pyRDDLGym_jax/examples/run_tune.py,sha256=zqrhvLR5PeWJv0NsRxDCzAPmvgPgz_1NrtM1xBy6ndU,3606
|
|
15
18
|
pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg,sha256=mE8MqhOlkHeXIGEVrnR3QY6I-_iy4uxFYRA71P1bmtk,347
|
|
16
19
|
pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg,sha256=CI_c-E2Ij2dzVbYFA3sAUEXQBaIDImaEH15HpLqGQRw,370
|
|
17
20
|
pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg,sha256=T-O4faHYfSMyNNjY2jQ9XPK772szjbC7Enip5AaEO_0,340
|
|
@@ -38,8 +41,9 @@ pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=mJo0woDevhQCSQfJg30ULVy9qG
|
|
|
38
41
|
pyRDDLGym_jax/examples/configs/tuning_drp.cfg,sha256=CQMpSCKTkGioO7U82mHMsYWFRsutULx0V6Wrl3YzV2U,504
|
|
39
42
|
pyRDDLGym_jax/examples/configs/tuning_replan.cfg,sha256=m_0nozFg_GVld0tGv92Xao_KONFJDq_vtiJKt5isqI8,501
|
|
40
43
|
pyRDDLGym_jax/examples/configs/tuning_slp.cfg,sha256=KHu8II6CA-h_HblwvWHylNRjSvvGS3VHxN7JQNR4p_Q,464
|
|
41
|
-
pyRDDLGym_jax-1.
|
|
42
|
-
pyRDDLGym_jax-1.
|
|
43
|
-
pyRDDLGym_jax-1.
|
|
44
|
-
pyRDDLGym_jax-1.
|
|
45
|
-
pyRDDLGym_jax-1.
|
|
44
|
+
pyRDDLGym_jax-1.2.dist-info/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
|
|
45
|
+
pyRDDLGym_jax-1.2.dist-info/METADATA,sha256=oWVOtC5AvAm2Xvdd507gXr3b6_aZLaH7LnOj6hADdgQ,15090
|
|
46
|
+
pyRDDLGym_jax-1.2.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
|
47
|
+
pyRDDLGym_jax-1.2.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
|
|
48
|
+
pyRDDLGym_jax-1.2.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
|
|
49
|
+
pyRDDLGym_jax-1.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|