pyRDDLGym-jax 0.2__py3-none-any.whl → 0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -0
- pyRDDLGym_jax/core/compiler.py +90 -68
- pyRDDLGym_jax/core/logic.py +188 -46
- pyRDDLGym_jax/core/planner.py +411 -195
- pyRDDLGym_jax/core/simulator.py +2 -1
- pyRDDLGym_jax/core/tuning.py +13 -10
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +1 -1
- pyRDDLGym_jax/examples/configs/default_drp.cfg +1 -1
- pyRDDLGym_jax/examples/configs/default_slp.cfg +1 -1
- pyRDDLGym_jax/examples/run_gym.py +2 -5
- pyRDDLGym_jax/examples/run_plan.py +6 -8
- pyRDDLGym_jax/examples/run_scipy.py +61 -0
- pyRDDLGym_jax/examples/run_tune.py +5 -6
- pyRDDLGym_jax-0.4.dist-info/METADATA +276 -0
- {pyRDDLGym_jax-0.2.dist-info → pyRDDLGym_jax-0.4.dist-info}/RECORD +20 -22
- {pyRDDLGym_jax-0.2.dist-info → pyRDDLGym_jax-0.4.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax/examples/configs/Pong_slp.cfg +0 -18
- pyRDDLGym_jax/examples/configs/SupplyChain_slp.cfg +0 -18
- pyRDDLGym_jax/examples/configs/Traffic_slp.cfg +0 -20
- pyRDDLGym_jax-0.2.dist-info/METADATA +0 -26
- {pyRDDLGym_jax-0.2.dist-info → pyRDDLGym_jax-0.4.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-0.2.dist-info → pyRDDLGym_jax-0.4.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/simulator.py
CHANGED
pyRDDLGym_jax/core/tuning.py
CHANGED
@@ -1,20 +1,18 @@
-from bayes_opt import BayesianOptimization
-from bayes_opt.util import UtilityFunction
 from copy import deepcopy
 import csv
 import datetime
-import jax
 from multiprocessing import get_context
-import numpy as np
 import os
 import time
 from typing import Any, Callable, Dict, Optional, Tuple
-
-Kwargs = Dict[str, Any]
-
 import warnings
 warnings.filterwarnings("ignore")
 
+from bayes_opt import BayesianOptimization
+from bayes_opt.util import UtilityFunction
+import jax
+import numpy as np
+
 from pyRDDLGym.core.debug.exception import raise_warning
 from pyRDDLGym.core.env import RDDLEnv
 
@@ -26,6 +24,8 @@ from pyRDDLGym_jax.core.planner import (
     JaxOnlineController
 )
 
+Kwargs = Dict[str, Any]
+
 
 # ===============================================================================
 #
@@ -368,7 +368,8 @@ def objective_slp(params, kwargs, key, index):
         train_seconds=kwargs['timeout_training'],
         model_params=model_params,
         policy_hyperparams=policy_hparams,
-
+        print_summary=False,
+        print_progress=False,
         tqdm_position=index)
 
     # initialize env for evaluation (need fresh copy to avoid concurrency)
@@ -499,7 +500,8 @@ def objective_replan(params, kwargs, key, index):
         train_seconds=kwargs['timeout_training'],
         model_params=model_params,
         policy_hyperparams=policy_hparams,
-
+        print_summary=False,
+        print_progress=False,
         tqdm_position=index)
 
     # initialize env for evaluation (need fresh copy to avoid concurrency)
@@ -626,7 +628,8 @@ def objective_drp(params, kwargs, key, index):
         train_seconds=kwargs['timeout_training'],
         model_params=model_params,
         policy_hyperparams=policy_hparams,
-
+        print_summary=False,
+        print_progress=False,
         tqdm_position=index)
 
     # initialize env for evaluation (need fresh copy to avoid concurrency)
pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg
CHANGED
@@ -6,7 +6,7 @@ tnorm_kwargs={}
 
 [Optimizer]
 method='JaxDeepReactivePolicy'
-method_kwargs={'topology': [
+method_kwargs={'topology': [64, 64]}
 optimizer='rmsprop'
 optimizer_kwargs={'learning_rate': 0.001}
 batch_size_train=1
@@ -14,5 +14,5 @@ batch_size_test=1
 
 [Training]
 key=42
-epochs=
-train_seconds=
+epochs=6000
+train_seconds=60
pyRDDLGym_jax/examples/run_gym.py
CHANGED
@@ -23,16 +23,13 @@ from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator
 def main(domain, instance, episodes=1, seed=42):
 
     # create the environment
-    env = pyRDDLGym.make(domain, instance,
-                         backend=JaxRDDLSimulator)
+    env = pyRDDLGym.make(domain, instance, backend=JaxRDDLSimulator)
 
-    #
+    # evaluate a random policy
     agent = RandomAgent(action_space=env.action_space,
                         num_actions=env.max_allowed_actions,
                         seed=seed)
     agent.evaluate(env, episodes=episodes, verbose=True, render=True, seed=seed)
-
-    # important when logging to save all traces
     env.close()
 
 
pyRDDLGym_jax/examples/run_plan.py
CHANGED
@@ -13,6 +13,7 @@ where:
     <domain> is the name of a domain located in the /Examples directory
     <instance> is the instance number
     <method> is either slp, drp, or replan
+    <episodes> is the optional number of evaluation rollouts
 '''
 import os
 import sys
@@ -28,29 +29,26 @@ from pyRDDLGym_jax.core.planner import (
 def main(domain, instance, method, episodes=1):
 
     # set up the environment
-    env = pyRDDLGym.make(domain, instance, vectorized=True
+    env = pyRDDLGym.make(domain, instance, vectorized=True)
 
     # load the config file with planner settings
     abs_path = os.path.dirname(os.path.abspath(__file__))
     config_path = os.path.join(abs_path, 'configs', f'{domain}_{method}.cfg')
     if not os.path.isfile(config_path):
-        raise_warning(f'Config file {
-                      f'using
-                      'red')
+        raise_warning(f'Config file {config_path} was not found, '
+                      f'using default_{method}.cfg.', 'red')
         config_path = os.path.join(abs_path, 'configs', f'default_{method}.cfg')
     planner_args, _, train_args = load_config(config_path)
 
     # create the planning algorithm
     planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
 
-    #
+    # evaluate the controller
     if method == 'replan':
         controller = JaxOnlineController(planner, **train_args)
     else:
-        controller = JaxOfflineController(planner, **train_args)
-
+        controller = JaxOfflineController(planner, **train_args)
     controller.evaluate(env, episodes=episodes, verbose=True, render=True)
-
     env.close()
 
 
pyRDDLGym_jax/examples/run_scipy.py
ADDED
@@ -0,0 +1,61 @@
+'''In this example, the user has the choice to run the Jax planner using an
+optimizer from scipy.minimize.
+
+The syntax for running this example is:
+
+    python run_scipy.py <domain> <instance> <method> [<episodes>]
+
+where:
+    <domain> is the name of a domain located in the /Examples directory
+    <instance> is the instance number
+    <method> is the name of a method provided to scipy.optimize.minimize()
+    <episodes> is the optional number of evaluation rollouts
+'''
+import os
+import sys
+import jax
+from scipy.optimize import minimize
+
+import pyRDDLGym
+from pyRDDLGym.core.debug.exception import raise_warning
+
+from pyRDDLGym_jax.core.planner import load_config, JaxBackpropPlanner, JaxOfflineController
+
+
+def main(domain, instance, method, episodes=1):
+
+    # set up the environment
+    env = pyRDDLGym.make(domain, instance, vectorized=True)
+
+    # load the config file with planner settings
+    abs_path = os.path.dirname(os.path.abspath(__file__))
+    config_path = os.path.join(abs_path, 'configs', f'{domain}_slp.cfg')
+    if not os.path.isfile(config_path):
+        raise_warning(f'Config file {config_path} was not found, '
+                      f'using default_slp.cfg.', 'red')
+        config_path = os.path.join(abs_path, 'configs', 'default_slp.cfg')
+    planner_args, _, train_args = load_config(config_path)
+
+    # create the planning algorithm
+    planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
+
+    # find the optimal plan
+    loss_fn, grad_fn, guess, unravel_fn = planner.as_optimization_problem()
+    opt = minimize(loss_fn, jac=grad_fn, x0=guess, method=method, options={'disp': True})
+    params = unravel_fn(opt.x)
+
+    # evaluate the optimal plan
+    controller = JaxOfflineController(planner, params=params, **train_args)
+    controller.evaluate(env, episodes=episodes, verbose=True, render=True)
+    env.close()
+
+
+if __name__ == "__main__":
+    args = sys.argv[1:]
+    if len(args) < 3:
+        print('python run_scipy.py <domain> <instance> <method> [<episodes>]')
+        exit(1)
+    kwargs = {'domain': args[0], 'instance': args[1], 'method': args[2]}
+    if len(args) >= 4: kwargs['episodes'] = int(args[3])
+    main(**kwargs)
+
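The new run_scipy.py above wires the planner into scipy.optimize.minimize through planner.as_optimization_problem(). For reference, a minimal sketch of driving that same interface programmatically rather than from the command line; the domain/instance names, the config path, and the 'L-BFGS-B' method choice are placeholders, not part of this diff:

```python
# Hedged sketch of the as_optimization_problem() interface shown in run_scipy.py
# above; domain/instance, config path, and 'L-BFGS-B' are placeholder choices.
import pyRDDLGym
from scipy.optimize import minimize

from pyRDDLGym_jax.core.planner import (
    load_config, JaxBackpropPlanner, JaxOfflineController
)

env = pyRDDLGym.make('Wildfire_MDP_ippc2014', '1', vectorized=True)
planner_args, _, train_args = load_config('/path/to/config.cfg')
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)

# the planner exposes a scalar loss, its gradient, an initial flat parameter
# vector, and a function mapping that vector back to structured plan parameters
loss_fn, grad_fn, guess, unravel_fn = planner.as_optimization_problem()
opt = minimize(loss_fn, jac=grad_fn, x0=guess, method='L-BFGS-B')
params = unravel_fn(opt.x)

# evaluate the resulting plan exactly as run_scipy.py does
controller = JaxOfflineController(planner, params=params, **train_args)
controller.evaluate(env, episodes=1, verbose=True, render=True)
env.close()
```

The script above does the same thing, except the scipy method name is taken from the command line.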
pyRDDLGym_jax/examples/run_tune.py
CHANGED
@@ -31,15 +31,14 @@ from pyRDDLGym_jax.core.planner import load_config
 def main(domain, instance, method, trials=5, iters=20, workers=4):
 
     # set up the environment
-    env = pyRDDLGym.make(domain, instance, vectorized=True
+    env = pyRDDLGym.make(domain, instance, vectorized=True)
 
     # load the config file with planner settings
     abs_path = os.path.dirname(os.path.abspath(__file__))
     config_path = os.path.join(abs_path, 'configs', f'{domain}_{method}.cfg')
     if not os.path.isfile(config_path):
-        raise_warning(f'Config file {
-                      f'using
-                      'red')
+        raise_warning(f'Config file {config_path} was not found, '
+                      f'using default_{method}.cfg.', 'red')
         config_path = os.path.join(abs_path, 'configs', f'default_{method}.cfg')
     planner_args, plan_args, train_args = load_config(config_path)
 
@@ -49,8 +48,7 @@ def main(domain, instance, method, trials=5, iters=20, workers=4):
     elif method == 'drp':
         tuning_class = JaxParameterTuningDRP
     elif method == 'replan':
-        tuning_class = JaxParameterTuningSLPReplan
-
+        tuning_class = JaxParameterTuningSLPReplan
     tuning = tuning_class(env=env,
                           train_epochs=train_args['epochs'],
                           timeout_training=train_args['train_seconds'],
@@ -60,6 +58,7 @@ def main(domain, instance, method, trials=5, iters=20, workers=4):
                           num_workers=workers,
                           gp_iters=iters)
 
+    # perform tuning and report best parameters
     best = tuning.tune(key=train_args['key'], filename=f'gp_{method}',
                        save_plot=True)
     print(f'best parameters found: {best}')
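For context, the tuning entry point touched above can also be driven directly from Python. The sketch below is assembled only from calls visible in this diff; the import path of JaxParameterTuningDRP and any constructor arguments not shown here (for example, the number of evaluation trials) are assumptions, and the domain, instance, worker count and iteration count are placeholders.

```python
# Hedged sketch based on the run_tune.py hunks above; the tuning-class import
# path and omitted constructor arguments are assumptions.
import pyRDDLGym
from pyRDDLGym_jax.core.planner import load_config
from pyRDDLGym_jax.core.tuning import JaxParameterTuningDRP  # assumed import path

env = pyRDDLGym.make('Wildfire_MDP_ippc2014', '1', vectorized=True)
planner_args, plan_args, train_args = load_config('/path/to/config.cfg')

# mirror the constructor call shown in the diff
tuning = JaxParameterTuningDRP(env=env,
                               train_epochs=train_args['epochs'],
                               timeout_training=train_args['train_seconds'],
                               num_workers=4,
                               gp_iters=20)

# run Bayesian optimization over hyper-parameters and report the best setting
best = tuning.tune(key=train_args['key'], filename='gp_drp', save_plot=True)
print(f'best parameters found: {best}')
```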
pyRDDLGym_jax-0.4.dist-info/METADATA
ADDED
@@ -0,0 +1,276 @@
+Metadata-Version: 2.1
+Name: pyRDDLGym-jax
+Version: 0.4
+Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
+Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
+Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
+Author-email: mike.gimelfarb@mail.utoronto.ca, ataitler@gmail.com, ssanner@mie.utoronto.ca
+License: MIT License
+Classifier: Development Status :: 3 - Alpha
+Classifier: Intended Audience :: Science/Research
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Natural Language :: English
+Classifier: Operating System :: OS Independent
+Classifier: Programming Language :: Python :: 3
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Requires-Python: >=3.8
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: pyRDDLGym >=2.0
+Requires-Dist: tqdm >=4.66
+Requires-Dist: bayesian-optimization >=1.4.3
+Requires-Dist: jax >=0.4.12
+Requires-Dist: optax >=0.1.9
+Requires-Dist: dm-haiku >=0.0.10
+Requires-Dist: tensorflow-probability >=0.21.0
+
+# pyRDDLGym-jax
+
+Author: [Mike Gimelfarb](https://mike-gimelfarb.github.io)
+
+This directory provides:
+1. automated translation and compilation of RDDL description files into [JAX](https://github.com/google/jax), converting any RDDL domain to a differentiable simulator!
+2. powerful, fast and scalable gradient-based planning algorithms, with extendible and flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and much more!
+
+> [!NOTE]
+> While Jax planners can support some discrete state/action problems through model relaxations, on some discrete problems they can perform poorly (though there is an ongoing effort to remedy this!).
+> If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
+
+## Contents
+
+- [Installation](#installation)
+- [Running from the Command Line](#running-from-the-command-line)
+- [Running from within Python](#running-from-within-python)
+- [Configuring the Planner](#configuring-the-planner)
+- [Simulation](#simulation)
+- [Manual Gradient Calculation](#manual-gradient-calculation)
+- [Citing pyRDDLGym-jax](#citing-pyrddlgym-jax)
+
+## Installation
+
+To use the compiler or planner without the automated hyper-parameter tuning, you will need the following packages installed:
+- ``pyRDDLGym>=2.0``
+- ``tqdm>=4.66``
+- ``jax>=0.4.12``
+- ``optax>=0.1.9``
+- ``dm-haiku>=0.0.10``
+- ``tensorflow-probability>=0.21.0``
+
+Additionally, if you wish to run the examples, you need ``rddlrepository>=2``.
+To run the automated tuning optimization, you will also need ``bayesian-optimization>=1.4.3``.
+
+You can install this package, together with all of its requirements, via pip:
+
+```shell
+pip install rddlrepository pyRDDLGym-jax
+```
+
+## Running from the Command Line
+
+A basic run script is provided to run the Jax Planner on any domain in ``rddlrepository``, and can be launched in the command line from the install directory of pyRDDLGym-jax:
+
+```shell
+python -m pyRDDLGym_jax.examples.run_plan <domain> <instance> <method> <episodes>
+```
+
+where:
+- ``domain`` is the domain identifier as specified in rddlrepository (i.e. Wildfire_MDP_ippc2014), or a path pointing to a valid ``domain.rddl`` file
+- ``instance`` is the instance identifier (i.e. 1, 2, ... 10), or a path pointing to a valid ``instance.rddl`` file
+- ``method`` is the planning method to use (i.e. drp, slp, replan)
+- ``episodes`` is the (optional) number of episodes to evaluate the learned policy.
+
+The ``method`` parameter supports three possible modes:
+- ``slp`` is the basic straight line planner described [in this paper](https://proceedings.neurips.cc/paper_files/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)
+- ``drp`` is the deep reactive policy network described [in this paper](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
+- ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step.
+
+A basic run script is also provided to run the automatic hyper-parameter tuning:
+
+```shell
+python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials> <iters> <workers>
+```
+
+where:
+- ``domain`` is the domain identifier as specified in rddlrepository (i.e. Wildfire_MDP_ippc2014)
+- ``instance`` is the instance identifier (i.e. 1, 2, ... 10)
+- ``method`` is the planning method to use (i.e. drp, slp, replan)
+- ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
+- ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
+- ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``.
+
+For example, the following will train the Jax Planner on the Quadcopter domain with 4 drones:
+
+```shell
+python -m pyRDDLGym_jax.examples.run_plan Quadcopter 1 slp
+```
+
+After several minutes of optimization, you should get a visualization as follows:
+
+<p align="center">
+<img src="Images/quadcopter.gif" width="400" height="400" margin=1/>
+</p>
+
+## Running from within Python
+
+To run the Jax planner from within a Python application, refer to the following example:
+
+```python
+import pyRDDLGym
+from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOfflineController
+
+# set up the environment (note the vectorized option must be True)
+env = pyRDDLGym.make("domain", "instance", vectorized=True)
+
+# create the planning algorithm
+planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
+controller = JaxOfflineController(planner, **train_args)
+
+# evaluate the planner
+controller.evaluate(env, episodes=1, verbose=True, render=True)
+env.close()
+```
+
+Here, we have used the straight-line controller, although you can configure the combination of planner and policy representation if you wish.
+All controllers are instances of pyRDDLGym's ``BaseAgent`` class, so they provide the ``evaluate()`` function to streamline interaction with the environment.
+The ``**planner_args`` and ``**train_args`` are keyword argument parameters to pass during initialization, but we strongly recommend creating and loading a config file as discussed in the next section.
+
+## Configuring the Planner
+
+The simplest way to configure the planner is to write and pass a configuration file with the necessary [hyper-parameters](https://pyrddlgym.readthedocs.io/en/latest/jax.html#configuring-pyrddlgym-jax).
+The basic structure of a configuration file is provided below for a straight-line planner:
+
+```ini
+[Model]
+logic='FuzzyLogic'
+logic_kwargs={'weight': 20}
+tnorm='ProductTNorm'
+tnorm_kwargs={}
+
+[Optimizer]
+method='JaxStraightLinePlan'
+method_kwargs={}
+optimizer='rmsprop'
+optimizer_kwargs={'learning_rate': 0.001}
+batch_size_train=1
+batch_size_test=1
+
+[Training]
+key=42
+epochs=5000
+train_seconds=30
+```
+
+The configuration file contains three sections:
+- ``[Model]`` specifies the fuzzy logic operations used to relax discrete operations to differentiable approximations; the ``weight`` dictates the quality of the approximation,
+and ``tnorm`` specifies the type of [fuzzy logic](https://en.wikipedia.org/wiki/T-norm_fuzzy_logics) for replacing logical operations in RDDL (e.g. ``ProductTNorm``, ``GodelTNorm``, ``LukasiewiczTNorm``)
+- ``[Optimizer]`` generally specifies the optimizer and plan settings; the ``method`` specifies the plan/policy representation (e.g. ``JaxStraightLinePlan``, ``JaxDeepReactivePolicy``), the gradient descent settings, learning rate, batch size, etc.
+- ``[Training]`` specifies computation limits, such as total training time and number of iterations, and options for printing or visualizing information from the planner.
+
+For a policy network approach, simply change the ``[Optimizer]`` settings like so:
+
+```ini
+...
+[Optimizer]
+method='JaxDeepReactivePolicy'
+method_kwargs={'topology': [128, 64], 'activation': 'tanh'}
+...
+```
+
+The configuration file must then be passed to the planner during initialization.
+For example, the [previous script here](#running-from-within-python) can be modified to set parameters from a config file:
+
+```python
+from pyRDDLGym_jax.core.planner import load_config
+
+# load the config file with planner settings
+planner_args, _, train_args = load_config("/path/to/config.cfg")
+
+# create the planning algorithm
+planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
+controller = JaxOfflineController(planner, **train_args)
+...
+```
+
+## Simulation
+
+The JAX compiler can be used as a backend for simulating and evaluating RDDL environments:
+
+```python
+import pyRDDLGym
+from pyRDDLGym.core.policy import RandomAgent
+from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator
+
+# create the environment
+env = pyRDDLGym.make("domain", "instance", backend=JaxRDDLSimulator)
+
+# evaluate the random policy
+agent = RandomAgent(action_space=env.action_space,
+                    num_actions=env.max_allowed_actions)
+agent.evaluate(env, verbose=True, render=True)
+```
+
+For some domains, the JAX backend could perform better than the numpy-based one, due to various compiler optimizations.
+In any event, the simulation results using the JAX backend should (almost) always match the numpy backend.
+
+## Manual Gradient Calculation
+
+For custom applications, it is desirable to compute gradients of the model that can be optimized downstream.
+Fortunately, we provide a very convenient function for compiling the transition/step function ``P(s, a, s')`` of the environment into JAX.
+
+```python
+import pyRDDLGym
+from pyRDDLGym_jax.core.planner import JaxRDDLCompilerWithGrad
+
+# set up the environment
+env = pyRDDLGym.make("domain", "instance", vectorized=True)
+
+# create the step function
+compiled = JaxRDDLCompilerWithGrad(rddl=env.model)
+compiled.compile()
+step_fn = compiled.compile_transition()
+```
+
+This will return a JAX compiled (pure) function requiring the following inputs:
+- ``key`` is the ``jax.random.PRNGKey`` key for reproducible randomness
+- ``actions`` is the dictionary of action fluent tensors
+- ``subs`` is the dictionary of state-fluent and non-fluent tensors
+- ``model_params`` are the parameters of the differentiable relaxations, such as ``weight``
+
+The function returns a dictionary containing a variety of variables, such as updated pvariables including next-state fluents (``pvar``), reward obtained (``reward``), error codes (``error``).
+It is thus possible to apply any JAX transformation to the output of the function, such as computing gradient using ``jax.grad()`` or batched simulation using ``jax.vmap()``.
+
+Compilation of entire rollouts is also possible by calling the ``compile_rollouts`` function.
+An [example is provided to illustrate how you can define your own policy class and compute the return gradient manually](https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/pyRDDLGym_jax/examples/run_gradient.py).
+
+## Citing pyRDDLGym-jax
+
+The [following citation](https://ojs.aaai.org/index.php/ICAPS/article/view/31480) describes the main ideas of the framework. Please cite it if you found it useful:
+
+```
+@inproceedings{gimelfarb2024jaxplan,
+    title={JaxPlan and GurobiPlan: Optimization Baselines for Replanning in Discrete and Mixed Discrete and Continuous Probabilistic Domains},
+    author={Michael Gimelfarb and Ayal Taitler and Scott Sanner},
+    booktitle={34th International Conference on Automated Planning and Scheduling},
+    year={2024},
+    url={https://openreview.net/forum?id=7IKtmUpLEH}
+}
+```
+
+The utility optimization is discussed in [this paper](https://ojs.aaai.org/index.php/AAAI/article/view/21226):
+
+```
+@inproceedings{patton2022distributional,
+    title={A distributional framework for risk-sensitive end-to-end planning in continuous mdps},
+    author={Patton, Noah and Jeong, Jihwan and Gimelfarb, Mike and Sanner, Scott},
+    booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
+    volume={36},
+    number={9},
+    pages={9894--9901},
+    year={2022}
+}
+```
+
+Some of the implementation details derive from the following literature, which you may wish to also cite in your research papers:
+- [Deep reactive policies for planning in stochastic nonlinear domains, AAAI 2019](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
+- [Scalable planning with tensorflow for hybrid nonlinear domains, NeurIPS 2017](https://proceedings.neurips.cc/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)
+
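The "Manual Gradient Calculation" section of the new README mentions that ``jax.grad()`` can be applied to the compiled transition but does not show it. A minimal sketch follows; the positional argument order of ``step_fn`` and the presence of a ``'reward'`` entry in its output are taken from the README's description rather than verified against the source, and the domain/instance names are placeholders.

```python
# Hedged sketch of differentiating the compiled transition described in the
# README above; the step_fn argument order and output keys are assumptions
# based on that text.
import jax
import jax.numpy as jnp
import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxRDDLCompilerWithGrad

env = pyRDDLGym.make('domain', 'instance', vectorized=True)

compiled = JaxRDDLCompilerWithGrad(rddl=env.model)
compiled.compile()
step_fn = compiled.compile_transition()

def one_step_reward(actions, key, subs, model_params):
    # reduce to a scalar so jax.grad applies even if the reward is batched
    return jnp.sum(step_fn(key, actions, subs, model_params)['reward'])

# gradient of the one-step reward with respect to the action-fluent tensors
grad_wrt_actions = jax.grad(one_step_reward, argnums=0)
```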
{pyRDDLGym_jax-0.2.dist-info → pyRDDLGym_jax-0.4.dist-info}/RECORD
CHANGED
@@ -1,26 +1,26 @@
-pyRDDLGym_jax/__init__.py,sha256=
+pyRDDLGym_jax/__init__.py,sha256=rexmxcBiCOcwctw4wGvk7UxS9MfZn_1CYXp53SoLKlU,19
 pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pyRDDLGym_jax/core/compiler.py,sha256=
-pyRDDLGym_jax/core/logic.py,sha256=
-pyRDDLGym_jax/core/planner.py,sha256=
-pyRDDLGym_jax/core/simulator.py,sha256=
-pyRDDLGym_jax/core/tuning.py,sha256=
+pyRDDLGym_jax/core/compiler.py,sha256=SnDN3-J84Wv_YVHoDmfM_U4Ob8uaFLGX4vEaeWC-ERY,90037
+pyRDDLGym_jax/core/logic.py,sha256=o1YAjMnXfi8gwb42kAigBmaf9uIYUWal9__FEkWohrk,26733
+pyRDDLGym_jax/core/planner.py,sha256=Hrwfn88bUu1LNZcnFC5psHPzcIUbPeF4Rn1pFO6_qH0,102655
+pyRDDLGym_jax/core/simulator.py,sha256=hWv6pr-4V-SSCzBYgdIPmKdUDMalft-Zh6dzOo5O9-0,8331
+pyRDDLGym_jax/core/tuning.py,sha256=D_kD8wjqMroCdtjE9eksR2UqrqXJqazsAKrMEHwPxYM,29589
 pyRDDLGym_jax/examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pyRDDLGym_jax/examples/run_gradient.py,sha256=KhXvijRDZ4V7N8NOI2WV8ePGpPna5_vnET61YwS7Tco,2919
-pyRDDLGym_jax/examples/run_gym.py,sha256=
-pyRDDLGym_jax/examples/run_plan.py,sha256=
-pyRDDLGym_jax/examples/
+pyRDDLGym_jax/examples/run_gym.py,sha256=rXvNWkxe4jHllvbvU_EOMji_2-2k5d4tbBKhpMm_Gaw,1526
+pyRDDLGym_jax/examples/run_plan.py,sha256=OENf8s-SrMlh7CYXNhanQiau35b4atLBJMNjgP88DCg,2463
+pyRDDLGym_jax/examples/run_scipy.py,sha256=wvcpWCvdjvYHntO95a7JYfY2fuCMUTKnqjJikW0PnL4,2291
+pyRDDLGym_jax/examples/run_tune.py,sha256=-M4KoBpg5lshQ4mmU0cnLs2i7-ldSIr_OcxHK7YA6bw,3273
 pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg,sha256=pbkz6ccgk5dHXp7cfYbZNFyJobpGyxUZleCy4fvlmaU,336
 pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg,sha256=OswO9YD4Xh1pw3R3LkUBb67WLtj5XlE3qnMQ5CKwPsM,332
 pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg,sha256=FxZ4xcg2j2PzeH-wUseRR280juQN5bJjoyt6PtI1W7c,329
-pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg,sha256=
+pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg,sha256=FTGFwRAGyeRrbDMh_FV8iv8ZHrlj3Htju4pfPNmKIcw,336
 pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg,sha256=wjtz86_Gz0RfQu3bbrz56PTXL8JMernINx7AtJuZCPs,314
-pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg,sha256=
+pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg,sha256=C_0BFyhGXbtF7N4vyeua2XkORbkj10HELC1GpzM0Uh4,415
 pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg,sha256=Yb4tFzUOj4epCCsofXAZo70lm5C2KzPIzI5PQHsa_Vk,429
 pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg,sha256=e7j-1Z66o7F-KZDSf2e8TQRWwkXOPRwrRFkIavK8G7g,327
 pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg,sha256=Z6CxaOxHv4oF6nW7SfSn_HshlQGDlNCPGASTnDTdL7Q,327
-pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg,sha256=
-pyRDDLGym_jax/examples/configs/Pong_slp.cfg,sha256=S45mBj5hTEshdeJ4rdRaty6YliggtEMkLQV6IYxEkyU,315
+pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg,sha256=Uy1mrX-AZMS-KBAhWXJ3c_QAhd4bRSWttDoFGYQ08lQ,315
 pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg,sha256=SM5_U4RwvvucHVAOdMG4vqH0Eg43f3WX9ZlV6aFPgTw,341
 pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg,sha256=lcqQ7P7X4qAbMlpkKKuYGn2luSZH-yFB7oi-eHj9Qng,332
 pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg,sha256=kG1-02ScmwsEwX7QIAZTD7si90Mb06b79G5oqcMQ9Hg,316
@@ -29,18 +29,16 @@ pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg,sha256=9QNl58PyoJYhmwvrhzUxlLE
 pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg,sha256=rrubYvC1q7Ff0ADV0GXtLw-rD9E4m7qfR66qxdYNTD8,339
 pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg,sha256=DAb-J2KwvJXViRRSHZe8aJwZiPljC28HtrKJPieeUCY,331
 pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg,sha256=QwKzCAFaErrTCHaJwDPLOxPHpNGNuAKMUoZjLLnMrNc,314
-pyRDDLGym_jax/examples/configs/SupplyChain_slp.cfg,sha256=vU_m6KjfNfaPuYosFdAWeYiV1zQGd6eNA17Yn5QB_BI,319
-pyRDDLGym_jax/examples/configs/Traffic_slp.cfg,sha256=03scuHAl6032YhyYy0w5MLMbTibhdbUZFHLhH2WWaPI,370
 pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg,sha256=QiJCJYOrdXXZfOTuPleGswREFxjGlqQSA0rw00YJWWI,318
 pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg,sha256=PGkgll7h5vhSF13JScKoQ-vpWaAGNJ_PUEhK7jEjNx4,340
 pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg,sha256=kEDAwsJQ_t9WPzPhIxfS0hRtgOhtFdJFfmPtTTJuwUE,454
 pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg,sha256=w2wipsA8PE5OBkYVIKajjtCOtiHqmMeY3XQVPAApwFk,371
 pyRDDLGym_jax/examples/configs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pyRDDLGym_jax/examples/configs/default_drp.cfg,sha256=
+pyRDDLGym_jax/examples/configs/default_drp.cfg,sha256=S2-5hPZtgAwUAFpiCAgSi-cnGhYHSDzMGMmatwhbM78,344
 pyRDDLGym_jax/examples/configs/default_replan.cfg,sha256=VWWPhOYBRq4cWwtrChw5pPqRmlX_nHbMvwciHd9hoLc,357
-pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=
-pyRDDLGym_jax-0.
-pyRDDLGym_jax-0.
-pyRDDLGym_jax-0.
-pyRDDLGym_jax-0.
-pyRDDLGym_jax-0.
+pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=TG3mtHUnCA7J2Gm9SczENpqAymTnzCE9dj1Z_R-FnVk,340
+pyRDDLGym_jax-0.4.dist-info/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
+pyRDDLGym_jax-0.4.dist-info/METADATA,sha256=-Kf8PLxf_7MiiYXzlZAf31kV1pT-Rurc7QY7dT3Fwk0,12857
+pyRDDLGym_jax-0.4.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+pyRDDLGym_jax-0.4.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
+pyRDDLGym_jax-0.4.dist-info/RECORD,,
@@ -1,18 +0,0 @@
-[Model]
-logic='FuzzyLogic'
-logic_kwargs={'weight': 1.0}
-tnorm='ProductTNorm'
-tnorm_kwargs={}
-
-[Optimizer]
-method='JaxStraightLinePlan'
-method_kwargs={}
-optimizer='rmsprop'
-optimizer_kwargs={'learning_rate': 0.001}
-batch_size_train=1
-batch_size_test=1
-
-[Training]
-key=42
-epochs=2000
-train_seconds=30