pyRDDLGym-jax 0.1-py3-none-any.whl → 0.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -0
- pyRDDLGym_jax/core/compiler.py +444 -221
- pyRDDLGym_jax/core/logic.py +129 -62
- pyRDDLGym_jax/core/planner.py +965 -394
- pyRDDLGym_jax/core/simulator.py +5 -7
- pyRDDLGym_jax/core/tuning.py +29 -15
- pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_drp.cfg → Cartpole_Continuous_gym_drp.cfg} +2 -3
- pyRDDLGym_jax/examples/configs/{HVAC_drp.cfg → HVAC_ippc2023_drp.cfg} +4 -4
- pyRDDLGym_jax/examples/configs/{MarsRover_drp.cfg → MarsRover_ippc2023_drp.cfg} +1 -0
- pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +19 -0
- pyRDDLGym_jax/examples/configs/{Pendulum_slp.cfg → Pendulum_gym_slp.cfg} +1 -1
- pyRDDLGym_jax/examples/configs/{Pong_slp.cfg → Quadcopter_drp.cfg} +5 -5
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +18 -0
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +1 -1
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +1 -1
- pyRDDLGym_jax/examples/configs/default_drp.cfg +19 -0
- pyRDDLGym_jax/examples/configs/default_replan.cfg +20 -0
- pyRDDLGym_jax/examples/configs/default_slp.cfg +19 -0
- pyRDDLGym_jax/examples/run_gradient.py +1 -1
- pyRDDLGym_jax/examples/run_gym.py +3 -7
- pyRDDLGym_jax/examples/run_plan.py +10 -5
- pyRDDLGym_jax/examples/run_scipy.py +61 -0
- pyRDDLGym_jax/examples/run_tune.py +8 -3
- {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/METADATA +1 -1
- pyRDDLGym_jax-0.3.dist-info/RECORD +44 -0
- {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax/examples/configs/SupplyChain_slp.cfg +0 -18
- pyRDDLGym_jax/examples/configs/Traffic_slp.cfg +0 -20
- pyRDDLGym_jax-0.1.dist-info/RECORD +0 -40
- /pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_replan.cfg → Cartpole_Continuous_gym_replan.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_slp.cfg → Cartpole_Continuous_gym_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{HVAC_slp.cfg → HVAC_ippc2023_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{MarsRover_slp.cfg → MarsRover_ippc2023_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{MountainCar_slp.cfg → MountainCar_Continuous_gym_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{PowerGen_drp.cfg → PowerGen_Continuous_drp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{PowerGen_replan.cfg → PowerGen_Continuous_replan.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{PowerGen_slp.cfg → PowerGen_Continuous_slp.cfg} +0 -0
- {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/simulator.py
CHANGED
@@ -1,6 +1,6 @@
 import jax
 import time
-from typing import Dict
+from typing import Dict, Optional
 
 from pyRDDLGym.core.compiler.model import RDDLLiftedModel
 from pyRDDLGym.core.debug.exception import (
@@ -20,9 +20,9 @@ Args = Dict[str, Value]
 class JaxRDDLSimulator(RDDLSimulator):
 
     def __init__(self, rddl: RDDLLiftedModel,
-                 key: jax.random.PRNGKey=None,
+                 key: Optional[jax.random.PRNGKey]=None,
                  raise_error: bool=True,
-                 logger: Logger=None,
+                 logger: Optional[Logger]=None,
                  keep_tensors: bool=False,
                  **compiler_args) -> None:
         '''Creates a new simulator for the given RDDL model with Jax as a backend.
@@ -56,10 +56,8 @@ class JaxRDDLSimulator(RDDLSimulator):
         rddl = self.rddl
 
         # compilation
-        if self.logger is not None:
-            self.logger.clear()
         compiled = JaxRDDLCompiler(rddl, logger=self.logger, **self.compiler_args)
-        compiled.compile(log_jax_expr=True)
+        compiled.compile(log_jax_expr=True, heading='SIMULATION MODEL')
 
         self.init_values = compiled.init_values
         self.levels = compiled.levels
@@ -96,7 +94,7 @@ class JaxRDDLSimulator(RDDLSimulator):
         self.precond_names = [f'Precondition {i}' for i in range(len(rddl.preconditions))]
         self.terminal_names = [f'Termination {i}' for i in range(len(rddl.terminations))]
 
-    def handle_error_code(self, error, msg) -> None:
+    def handle_error_code(self, error: int, msg: str) -> None:
         if self.raise_error:
             errors = JaxRDDLCompiler.get_error_messages(error)
             if errors:
pyRDDLGym_jax/core/tuning.py
CHANGED
@@ -8,7 +8,9 @@ from multiprocessing import get_context
 import numpy as np
 import os
 import time
-from typing import Callable, Dict, Tuple
+from typing import Any, Callable, Dict, Optional, Tuple
+
+Kwargs = Dict[str, Any]
 
 import warnings
 warnings.filterwarnings("ignore")
@@ -45,15 +47,15 @@ class JaxParameterTuning:
                  timeout_tuning: float=np.inf,
                  eval_trials: int=5,
                  verbose: bool=True,
-                 planner_kwargs:
-                 plan_kwargs:
+                 planner_kwargs: Optional[Kwargs]=None,
+                 plan_kwargs: Optional[Kwargs]=None,
                  pool_context: str='spawn',
                  num_workers: int=1,
                  poll_frequency: float=0.2,
                  gp_iters: int=25,
-                 acquisition=None,
-                 gp_init_kwargs:
-                 gp_params:
+                 acquisition: Optional[UtilityFunction]=None,
+                 gp_init_kwargs: Optional[Kwargs]=None,
+                 gp_params: Optional[Kwargs]=None) -> None:
         '''Creates a new instance for tuning hyper-parameters for Jax planners
         on the given RDDL domain and instance.
 
@@ -93,13 +95,21 @@ class JaxParameterTuning:
         self.timeout_tuning = timeout_tuning
         self.eval_trials = eval_trials
         self.verbose = verbose
+        if planner_kwargs is None:
+            planner_kwargs = {}
         self.planner_kwargs = planner_kwargs
+        if plan_kwargs is None:
+            plan_kwargs = {}
         self.plan_kwargs = plan_kwargs
         self.pool_context = pool_context
         self.num_workers = num_workers
         self.poll_frequency = poll_frequency
         self.gp_iters = gp_iters
+        if gp_init_kwargs is None:
+            gp_init_kwargs = {}
         self.gp_init_kwargs = gp_init_kwargs
+        if gp_params is None:
+            gp_params = {'n_restarts_optimizer': 10}
         self.gp_params = gp_params
 
         # create acquisition function
@@ -109,7 +119,7 @@ class JaxParameterTuning:
         acquisition, self.acq_args = JaxParameterTuning._annealing_utility(num_samples)
         self.acquisition = acquisition
 
-    def summarize_hyperparameters(self):
+    def summarize_hyperparameters(self) -> None:
         print(f'hyperparameter optimizer parameters:\n'
               f' tuned_hyper_parameters ={self.hyperparams_dict}\n'
               f' initialization_args ={self.gp_init_kwargs}\n'
@@ -150,8 +160,9 @@ class JaxParameterTuning:
         pid = os.getpid()
         return index, pid, params, target
 
-    def tune(self, key: jax.random.PRNGKey,
-
+    def tune(self, key: jax.random.PRNGKey,
+             filename: str,
+             save_plot: bool=False) -> Dict[str, Any]:
         '''Tunes the hyper-parameters for Jax planner, returns the best found.'''
         self.summarize_hyperparameters()
 
@@ -357,14 +368,15 @@ def objective_slp(params, kwargs, key, index):
                        train_seconds=kwargs['timeout_training'],
                        model_params=model_params,
                        policy_hyperparams=policy_hparams,
-
+                       print_summary=False,
+                       print_progress=False,
                        tqdm_position=index)
 
     # initialize env for evaluation (need fresh copy to avoid concurrency)
     env = RDDLEnv(domain=kwargs['domain'],
                   instance=kwargs['instance'],
                   vectorized=True,
-                  enforce_action_constraints=
+                  enforce_action_constraints=False)
 
     # perform training
     average_reward = 0.0
@@ -488,14 +500,15 @@ def objective_replan(params, kwargs, key, index):
                        train_seconds=kwargs['timeout_training'],
                        model_params=model_params,
                        policy_hyperparams=policy_hparams,
-
+                       print_summary=False,
+                       print_progress=False,
                        tqdm_position=index)
 
     # initialize env for evaluation (need fresh copy to avoid concurrency)
     env = RDDLEnv(domain=kwargs['domain'],
                   instance=kwargs['instance'],
                   vectorized=True,
-                  enforce_action_constraints=
+                  enforce_action_constraints=False)
 
     # perform training
     average_reward = 0.0
@@ -615,14 +628,15 @@ def objective_drp(params, kwargs, key, index):
                        train_seconds=kwargs['timeout_training'],
                        model_params=model_params,
                        policy_hyperparams=policy_hparams,
-
+                       print_summary=False,
+                       print_progress=False,
                        tqdm_position=index)
 
     # initialize env for evaluation (need fresh copy to avoid concurrency)
     env = RDDLEnv(domain=kwargs['domain'],
                   instance=kwargs['instance'],
                   vectorized=True,
-                  enforce_action_constraints=
+                  enforce_action_constraints=False)
 
     # perform training
     average_reward = 0.0
pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_drp.cfg → Cartpole_Continuous_gym_drp.cfg}
RENAMED
@@ -8,12 +8,11 @@ tnorm_kwargs={}
 method='JaxDeepReactivePolicy'
 method_kwargs={'topology': [32, 32]}
 optimizer='rmsprop'
-optimizer_kwargs={'learning_rate': 0.
+optimizer_kwargs={'learning_rate': 0.005}
 batch_size_train=1
 batch_size_test=1
-clip_grad=1.0
 
 [Training]
 key=42
-epochs=
+epochs=2000
 train_seconds=30
pyRDDLGym_jax/examples/configs/{HVAC_drp.cfg → HVAC_ippc2023_drp.cfg}
RENAMED
@@ -6,13 +6,13 @@ tnorm_kwargs={}
 
 [Optimizer]
 method='JaxDeepReactivePolicy'
-method_kwargs={'topology': [
+method_kwargs={'topology': [64, 64]}
 optimizer='rmsprop'
-optimizer_kwargs={'learning_rate': 0.
+optimizer_kwargs={'learning_rate': 0.001}
 batch_size_train=1
 batch_size_test=1
 
 [Training]
 key=42
-epochs=
-train_seconds=
+epochs=6000
+train_seconds=60
pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg
ADDED
@@ -0,0 +1,19 @@
+[Model]
+logic='FuzzyLogic'
+logic_kwargs={'weight': 10}
+tnorm='ProductTNorm'
+tnorm_kwargs={}
+
+[Optimizer]
+method='JaxStraightLinePlan'
+method_kwargs={}
+optimizer='rmsprop'
+optimizer_kwargs={'learning_rate': 1.0}
+batch_size_train=1
+batch_size_test=1
+clip_grad=1.0
+
+[Training]
+key=42
+epochs=1000
+train_seconds=30
pyRDDLGym_jax/examples/configs/{Pong_slp.cfg → Quadcopter_drp.cfg}
RENAMED
@@ -1,12 +1,12 @@
 [Model]
 logic='FuzzyLogic'
-logic_kwargs={'weight':
+logic_kwargs={'weight': 100}
 tnorm='ProductTNorm'
 tnorm_kwargs={}
 
 [Optimizer]
-method='
-method_kwargs={}
+method='JaxDeepReactivePolicy'
+method_kwargs={'topology': [256, 128], 'activation': 'tanh'}
 optimizer='rmsprop'
 optimizer_kwargs={'learning_rate': 0.001}
 batch_size_train=1
@@ -14,5 +14,5 @@ batch_size_test=1
 
 [Training]
 key=42
-epochs=
-train_seconds=
+epochs=100000
+train_seconds=360
pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg
ADDED
@@ -0,0 +1,18 @@
+[Model]
+logic='FuzzyLogic'
+logic_kwargs={'weight': 10}
+tnorm='ProductTNorm'
+tnorm_kwargs={}
+
+[Optimizer]
+method='JaxDeepReactivePolicy'
+method_kwargs={'topology': [64, 32]}
+optimizer='rmsprop'
+optimizer_kwargs={'learning_rate': 0.0002}
+batch_size_train=32
+batch_size_test=32
+
+[Training]
+key=42
+epochs=5000
+train_seconds=60
pyRDDLGym_jax/examples/configs/default_drp.cfg
ADDED
@@ -0,0 +1,19 @@
+[Model]
+logic='FuzzyLogic'
+logic_kwargs={'weight': 20}
+tnorm='ProductTNorm'
+tnorm_kwargs={}
+
+[Optimizer]
+method='JaxDeepReactivePolicy'
+method_kwargs={}
+optimizer='rmsprop'
+optimizer_kwargs={'learning_rate': 0.0001}
+batch_size_train=32
+batch_size_test=32
+
+[Training]
+key=42
+epochs=30000
+train_seconds=60
+policy_hyperparams=2.0
pyRDDLGym_jax/examples/configs/default_replan.cfg
ADDED
@@ -0,0 +1,20 @@
+[Model]
+logic='FuzzyLogic'
+logic_kwargs={'weight': 20}
+tnorm='ProductTNorm'
+tnorm_kwargs={}
+
+[Optimizer]
+method='JaxStraightLinePlan'
+method_kwargs={}
+optimizer='rmsprop'
+optimizer_kwargs={'learning_rate': 0.01}
+batch_size_train=32
+batch_size_test=32
+rollout_horizon=5
+
+[Training]
+key=42
+epochs=2000
+train_seconds=1
+policy_hyperparams=2.0
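Note: a minimal sketch of how a replanning config like this one is consumed, mirroring run_plan.py (the domain, instance and relative config path are placeholders); the online controller re-optimizes the short-horizon plan at every decision step:

import os
import pyRDDLGym
from pyRDDLGym_jax.core.planner import load_config, JaxBackpropPlanner, JaxOnlineController

# placeholder domain/instance and config location
env = pyRDDLGym.make('Cartpole_Continuous_gym', '0', vectorized=True)
planner_args, _, train_args = load_config(os.path.join('configs', 'default_replan.cfg'))

planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
controller = JaxOnlineController(planner, **train_args)   # re-plans each step
controller.evaluate(env, episodes=1, verbose=True, render=True)
env.close()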
pyRDDLGym_jax/examples/configs/default_slp.cfg
ADDED
@@ -0,0 +1,19 @@
+[Model]
+logic='FuzzyLogic'
+logic_kwargs={'weight': 20}
+tnorm='ProductTNorm'
+tnorm_kwargs={}
+
+[Optimizer]
+method='JaxStraightLinePlan'
+method_kwargs={}
+optimizer='rmsprop'
+optimizer_kwargs={'learning_rate': 0.01}
+batch_size_train=32
+batch_size_test=32
+
+[Training]
+key=42
+epochs=30000
+train_seconds=60
+policy_hyperparams=2.0
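Note: the offline counterpart of the previous sketch (again with placeholder names); load_config splits a config into planner, plan and training argument groups, and the offline controller optimizes the straight-line plan once before executing it:

import os
import pyRDDLGym
from pyRDDLGym_jax.core.planner import load_config, JaxBackpropPlanner, JaxOfflineController

# placeholder domain/instance and config location
env = pyRDDLGym.make('Cartpole_Continuous_gym', '0', vectorized=True)
planner_args, plan_args, train_args = load_config(os.path.join('configs', 'default_slp.cfg'))

planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
controller = JaxOfflineController(planner, **train_args)   # trains once, then acts
controller.evaluate(env, episodes=1, verbose=True, render=True)
env.close()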
pyRDDLGym_jax/examples/run_gradient.py
CHANGED
@@ -91,7 +91,7 @@ def main():
     my_args = [jax.random.PRNGKey(42), params, None, subs, compiler.model_params]
 
     # print the fluents over the trajectory, return and gradient
-    print(step_fn(*my_args)['
+    print(step_fn(*my_args)['fluents'])
     print(sum_of_rewards(*my_args))
     print(jax.grad(sum_of_rewards, argnums=1)(*my_args))
 
pyRDDLGym_jax/examples/run_gym.py
CHANGED
@@ -23,17 +23,13 @@ from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator
 def main(domain, instance, episodes=1, seed=42):
 
     # create the environment
-    env = pyRDDLGym.make(domain, instance,
-                         backend=JaxRDDLSimulator)
-    env.seed(seed)
+    env = pyRDDLGym.make(domain, instance, backend=JaxRDDLSimulator)
 
-    #
+    # evaluate a random policy
     agent = RandomAgent(action_space=env.action_space,
                         num_actions=env.max_allowed_actions,
                         seed=seed)
-    agent.evaluate(env, episodes=episodes, verbose=True, render=True)
-
-    # important when logging to save all traces
+    agent.evaluate(env, episodes=episodes, verbose=True, render=True, seed=seed)
     env.close()
 
 
pyRDDLGym_jax/examples/run_plan.py
CHANGED
@@ -13,11 +13,14 @@ where:
     <domain> is the name of a domain located in the /Examples directory
     <instance> is the instance number
     <method> is either slp, drp, or replan
+    <episodes> is the optional number of evaluation rollouts
 '''
 import os
 import sys
 
 import pyRDDLGym
+from pyRDDLGym.core.debug.exception import raise_warning
+
 from pyRDDLGym_jax.core.planner import (
     load_config, JaxBackpropPlanner, JaxOfflineController, JaxOnlineController
 )
@@ -26,24 +29,26 @@ from pyRDDLGym_jax.core.planner import (
 def main(domain, instance, method, episodes=1):
 
     # set up the environment
-    env = pyRDDLGym.make(domain, instance, vectorized=True
+    env = pyRDDLGym.make(domain, instance, vectorized=True)
 
     # load the config file with planner settings
     abs_path = os.path.dirname(os.path.abspath(__file__))
     config_path = os.path.join(abs_path, 'configs', f'{domain}_{method}.cfg')
+    if not os.path.isfile(config_path):
+        raise_warning(f'Config file {config_path} was not found, '
+                      f'using default_{method}.cfg.', 'red')
+        config_path = os.path.join(abs_path, 'configs', f'default_{method}.cfg')
     planner_args, _, train_args = load_config(config_path)
 
     # create the planning algorithm
     planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
 
-    #
+    # evaluate the controller
     if method == 'replan':
         controller = JaxOnlineController(planner, **train_args)
     else:
-        controller = JaxOfflineController(planner, **train_args)
-
+        controller = JaxOfflineController(planner, **train_args)
     controller.evaluate(env, episodes=episodes, verbose=True, render=True)
-
     env.close()
 
 
pyRDDLGym_jax/examples/run_scipy.py
ADDED
@@ -0,0 +1,61 @@
+'''In this example, the user has the choice to run the Jax planner using an
+optimizer from scipy.minimize.
+
+The syntax for running this example is:
+
+    python run_scipy.py <domain> <instance> <method> [<episodes>]
+
+where:
+    <domain> is the name of a domain located in the /Examples directory
+    <instance> is the instance number
+    <method> is the name of a method provided to scipy.optimize.minimize()
+    <episodes> is the optional number of evaluation rollouts
+'''
+import os
+import sys
+import jax
+from scipy.optimize import minimize
+
+import pyRDDLGym
+from pyRDDLGym.core.debug.exception import raise_warning
+
+from pyRDDLGym_jax.core.planner import load_config, JaxBackpropPlanner, JaxOfflineController
+
+
+def main(domain, instance, method, episodes=1):
+
+    # set up the environment
+    env = pyRDDLGym.make(domain, instance, vectorized=True)
+
+    # load the config file with planner settings
+    abs_path = os.path.dirname(os.path.abspath(__file__))
+    config_path = os.path.join(abs_path, 'configs', f'{domain}_slp.cfg')
+    if not os.path.isfile(config_path):
+        raise_warning(f'Config file {config_path} was not found, '
+                      f'using default_slp.cfg.', 'red')
+        config_path = os.path.join(abs_path, 'configs', 'default_slp.cfg')
+    planner_args, _, train_args = load_config(config_path)
+
+    # create the planning algorithm
+    planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
+
+    # find the optimal plan
+    loss_fn, grad_fn, guess, unravel_fn = planner.as_optimization_problem()
+    opt = minimize(loss_fn, jac=grad_fn, x0=guess, method=method, options={'disp': True})
+    params = unravel_fn(opt.x)
+
+    # evaluate the optimal plan
+    controller = JaxOfflineController(planner, params=params, **train_args)
+    controller.evaluate(env, episodes=episodes, verbose=True, render=True)
+    env.close()
+
+
+if __name__ == "__main__":
+    args = sys.argv[1:]
+    if len(args) < 3:
+        print('python run_scipy.py <domain> <instance> <method> [<episodes>]')
+        exit(1)
+    kwargs = {'domain': args[0], 'instance': args[1], 'method': args[2]}
+    if len(args) >= 4: kwargs['episodes'] = int(args[3])
+    main(**kwargs)
+
pyRDDLGym_jax/examples/run_tune.py
CHANGED
@@ -20,6 +20,7 @@ import os
 import sys
 
 import pyRDDLGym
+from pyRDDLGym.core.debug.exception import raise_warning
 
 from pyRDDLGym_jax.core.tuning import (
     JaxParameterTuningDRP, JaxParameterTuningSLP, JaxParameterTuningSLPReplan
@@ -30,11 +31,15 @@ from pyRDDLGym_jax.core.planner import load_config
 def main(domain, instance, method, trials=5, iters=20, workers=4):
 
     # set up the environment
-    env = pyRDDLGym.make(domain, instance, vectorized=True
+    env = pyRDDLGym.make(domain, instance, vectorized=True)
 
     # load the config file with planner settings
     abs_path = os.path.dirname(os.path.abspath(__file__))
     config_path = os.path.join(abs_path, 'configs', f'{domain}_{method}.cfg')
+    if not os.path.isfile(config_path):
+        raise_warning(f'Config file {config_path} was not found, '
+                      f'using default_{method}.cfg.', 'red')
+        config_path = os.path.join(abs_path, 'configs', f'default_{method}.cfg')
     planner_args, plan_args, train_args = load_config(config_path)
 
     # define algorithm to perform tuning
@@ -43,8 +48,7 @@ def main(domain, instance, method, trials=5, iters=20, workers=4):
     elif method == 'drp':
         tuning_class = JaxParameterTuningDRP
     elif method == 'replan':
-        tuning_class = JaxParameterTuningSLPReplan
-
+        tuning_class = JaxParameterTuningSLPReplan
     tuning = tuning_class(env=env,
                           train_epochs=train_args['epochs'],
                           timeout_training=train_args['train_seconds'],
@@ -54,6 +58,7 @@ def main(domain, instance, method, trials=5, iters=20, workers=4):
                           num_workers=workers,
                           gp_iters=iters)
 
+    # perform tuning and report best parameters
     best = tuning.tune(key=train_args['key'], filename=f'gp_{method}',
                        save_plot=True)
     print(f'best parameters found: {best}')
{pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: pyRDDLGym-jax
-Version: 0.1
+Version: 0.3
 Summary: pyRDDLGym-jax: JAX compilation of RDDL description files, and a differentiable planner in JAX.
 Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
 Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
pyRDDLGym_jax-0.3.dist-info/RECORD
ADDED
@@ -0,0 +1,44 @@
+pyRDDLGym_jax/__init__.py,sha256=Cl7DWkrPP64Ofc2ILXnudFOdnCuKs2p0Pm7ykZOOPh4,19
+pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+pyRDDLGym_jax/core/compiler.py,sha256=m7p0CHOU4Wma0cKMu_WQwfoieIQ2pXD68hZ8BFJ970A,89103
+pyRDDLGym_jax/core/logic.py,sha256=zujSHiR5KhTO81E5Zn8Gy_xSzVzfDskFCGvZygFRdMI,21930
+pyRDDLGym_jax/core/planner.py,sha256=1BtU1G3rihRZaMfNu0VtbSl1LXEXu6pT75EkF6-WVnM,101827
+pyRDDLGym_jax/core/simulator.py,sha256=fp6bep3XwwBWED0w7_4qhiwDjkSka6B2prwdNcPRCMc,8329
+pyRDDLGym_jax/core/tuning.py,sha256=Dv0YyOgGnej-zdVymWdkVg0MZjm2lNRfr7gySzFOeow,29589
+pyRDDLGym_jax/examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+pyRDDLGym_jax/examples/run_gradient.py,sha256=KhXvijRDZ4V7N8NOI2WV8ePGpPna5_vnET61YwS7Tco,2919
+pyRDDLGym_jax/examples/run_gym.py,sha256=rXvNWkxe4jHllvbvU_EOMji_2-2k5d4tbBKhpMm_Gaw,1526
+pyRDDLGym_jax/examples/run_plan.py,sha256=OENf8s-SrMlh7CYXNhanQiau35b4atLBJMNjgP88DCg,2463
+pyRDDLGym_jax/examples/run_scipy.py,sha256=wvcpWCvdjvYHntO95a7JYfY2fuCMUTKnqjJikW0PnL4,2291
+pyRDDLGym_jax/examples/run_tune.py,sha256=-M4KoBpg5lshQ4mmU0cnLs2i7-ldSIr_OcxHK7YA6bw,3273
+pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg,sha256=pbkz6ccgk5dHXp7cfYbZNFyJobpGyxUZleCy4fvlmaU,336
+pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg,sha256=OswO9YD4Xh1pw3R3LkUBb67WLtj5XlE3qnMQ5CKwPsM,332
+pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg,sha256=FxZ4xcg2j2PzeH-wUseRR280juQN5bJjoyt6PtI1W7c,329
+pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg,sha256=FTGFwRAGyeRrbDMh_FV8iv8ZHrlj3Htju4pfPNmKIcw,336
+pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg,sha256=wjtz86_Gz0RfQu3bbrz56PTXL8JMernINx7AtJuZCPs,314
+pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg,sha256=C_0BFyhGXbtF7N4vyeua2XkORbkj10HELC1GpzM0Uh4,415
+pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg,sha256=Yb4tFzUOj4epCCsofXAZo70lm5C2KzPIzI5PQHsa_Vk,429
+pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg,sha256=e7j-1Z66o7F-KZDSf2e8TQRWwkXOPRwrRFkIavK8G7g,327
+pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg,sha256=Z6CxaOxHv4oF6nW7SfSn_HshlQGDlNCPGASTnDTdL7Q,327
+pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg,sha256=Uy1mrX-AZMS-KBAhWXJ3c_QAhd4bRSWttDoFGYQ08lQ,315
+pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg,sha256=SM5_U4RwvvucHVAOdMG4vqH0Eg43f3WX9ZlV6aFPgTw,341
+pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg,sha256=lcqQ7P7X4qAbMlpkKKuYGn2luSZH-yFB7oi-eHj9Qng,332
+pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg,sha256=kG1-02ScmwsEwX7QIAZTD7si90Mb06b79G5oqcMQ9Hg,316
+pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg,sha256=yGMBWiVZT8KdZ1PhQ4kIxPvnjht1ss0UheTV-Nt9oaA,364
+pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg,sha256=9QNl58PyoJYhmwvrhzUxlLEy8vGbmwE6lRuOdvhLjGQ,317
+pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg,sha256=rrubYvC1q7Ff0ADV0GXtLw-rD9E4m7qfR66qxdYNTD8,339
+pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg,sha256=DAb-J2KwvJXViRRSHZe8aJwZiPljC28HtrKJPieeUCY,331
+pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg,sha256=QwKzCAFaErrTCHaJwDPLOxPHpNGNuAKMUoZjLLnMrNc,314
+pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg,sha256=QiJCJYOrdXXZfOTuPleGswREFxjGlqQSA0rw00YJWWI,318
+pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg,sha256=PGkgll7h5vhSF13JScKoQ-vpWaAGNJ_PUEhK7jEjNx4,340
+pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg,sha256=kEDAwsJQ_t9WPzPhIxfS0hRtgOhtFdJFfmPtTTJuwUE,454
+pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg,sha256=w2wipsA8PE5OBkYVIKajjtCOtiHqmMeY3XQVPAApwFk,371
+pyRDDLGym_jax/examples/configs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+pyRDDLGym_jax/examples/configs/default_drp.cfg,sha256=S2-5hPZtgAwUAFpiCAgSi-cnGhYHSDzMGMmatwhbM78,344
+pyRDDLGym_jax/examples/configs/default_replan.cfg,sha256=VWWPhOYBRq4cWwtrChw5pPqRmlX_nHbMvwciHd9hoLc,357
+pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=TG3mtHUnCA7J2Gm9SczENpqAymTnzCE9dj1Z_R-FnVk,340
+pyRDDLGym_jax-0.3.dist-info/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
+pyRDDLGym_jax-0.3.dist-info/METADATA,sha256=e_1MlMdQoqQHW-KA2OSIZzIAQyfe-jDtMOxkIyhmLmI,1085
+pyRDDLGym_jax-0.3.dist-info/WHEEL,sha256=y4mX-SOX4fYIkonsAGA5N0Oy-8_gI4FXw5HNI1xqvWg,91
+pyRDDLGym_jax-0.3.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
+pyRDDLGym_jax-0.3.dist-info/RECORD,,
pyRDDLGym_jax/examples/configs/SupplyChain_slp.cfg
DELETED
@@ -1,18 +0,0 @@
-[Model]
-logic='FuzzyLogic'
-logic_kwargs={'weight': 10.0}
-tnorm='ProductTNorm'
-tnorm_kwargs={}
-
-[Optimizer]
-method='JaxStraightLinePlan'
-method_kwargs={}
-optimizer='rmsprop'
-optimizer_kwargs={'learning_rate': 0.005}
-batch_size_train=8
-batch_size_test=8
-
-[Training]
-key=42
-epochs=10000
-train_seconds=90
pyRDDLGym_jax/examples/configs/Traffic_slp.cfg
DELETED
@@ -1,20 +0,0 @@
-[Model]
-logic='FuzzyLogic'
-logic_kwargs={'weight': 1000}
-tnorm='ProductTNorm'
-tnorm_kwargs={}
-
-[Optimizer]
-method='JaxStraightLinePlan'
-method_kwargs={}
-optimizer='rmsprop'
-optimizer_kwargs={'learning_rate': 0.001}
-batch_size_train=16
-batch_size_test=16
-clip_grad=1.0
-
-[Training]
-key=42
-epochs=200
-train_seconds=30
-policy_hyperparams={'advance': 10.0}