pyRDDLGym_jax-0.5-py3-none-any.whl → pyRDDLGym_jax-1.0-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +463 -592
- pyRDDLGym_jax/core/logic.py +784 -544
- pyRDDLGym_jax/core/planner.py +329 -463
- pyRDDLGym_jax/core/simulator.py +7 -5
- pyRDDLGym_jax/core/tuning.py +379 -568
- pyRDDLGym_jax/core/visualization.py +1463 -0
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +4 -5
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
- pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
- pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
- pyRDDLGym_jax/examples/run_plan.py +4 -1
- pyRDDLGym_jax/examples/run_tune.py +40 -27
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/METADATA +161 -104
- pyRDDLGym_jax-1.0.dist-info/RECORD +45 -0
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
- pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
- pyRDDLGym_jax-0.5.dist-info/RECORD +0 -44
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/top_level.txt +0 -0
```diff
@@ -1,12 +1,12 @@
 [Model]
 logic='FuzzyLogic'
-
-
-
+comparison_kwargs={'weight': 20}
+rounding_kwargs={'weight': 20}
+control_kwargs={'weight': 20}
 
 [Optimizer]
 method='JaxDeepReactivePolicy'
-method_kwargs={'topology': [32,
+method_kwargs={'topology': [32, 16]}
 optimizer='rmsprop'
 optimizer_kwargs={'learning_rate': 0.005}
 batch_size_train=1
@@ -14,5 +14,4 @@ batch_size_test=1
 
 [Training]
 key=42
-epochs=
-train_seconds=30
+epochs=1000
```
```diff
@@ -1,8 +1,8 @@
 [Model]
 logic='FuzzyLogic'
-
-
-
+comparison_kwargs={'weight': 50}
+rounding_kwargs={'weight': 50}
+control_kwargs={'weight': 50}
 
 [Optimizer]
 method='JaxStraightLinePlan'
@@ -15,6 +15,5 @@ rollout_horizon=30
 
 [Training]
 key=42
-
-train_seconds=1
+train_seconds=0.5
 print_summary=False
```
```diff
@@ -1,19 +1,18 @@
 [Model]
 logic='FuzzyLogic'
-
-
-
+comparison_kwargs={'weight': 30}
+rounding_kwargs={'weight': 30}
+control_kwargs={'weight': 30}
 
 [Optimizer]
 method='JaxStraightLinePlan'
 method_kwargs={}
 optimizer='rmsprop'
-optimizer_kwargs={'learning_rate': 0.
+optimizer_kwargs={'learning_rate': 0.002}
 batch_size_train=1
 batch_size_test=1
 clip_grad=1.0
 
 [Training]
 key=42
-epochs=5000
-train_seconds=30
+epochs=5000
```
```diff
@@ -1,14 +1,14 @@
 [Model]
 logic='FuzzyLogic'
-
-
-
+comparison_kwargs={'weight': 5}
+rounding_kwargs={'weight': 5}
+control_kwargs={'weight': 5}
 
 [Optimizer]
 method='JaxStraightLinePlan'
 method_kwargs={}
 optimizer='rmsprop'
-optimizer_kwargs={'learning_rate': 0.
+optimizer_kwargs={'learning_rate': 0.02}
 batch_size_train=1
 batch_size_test=1
 
```
```diff
@@ -1,14 +1,14 @@
 [Model]
 logic='FuzzyLogic'
-
-
-
+comparison_kwargs={'weight': 50}
+rounding_kwargs={'weight': 50}
+control_kwargs={'weight': 50}
 
 [Optimizer]
 method='JaxStraightLinePlan'
 method_kwargs={}
 optimizer='rmsprop'
-optimizer_kwargs={'learning_rate': 0.
+optimizer_kwargs={'learning_rate': 0.03}
 batch_size_train=1
 batch_size_test=1
 
```
```diff
@@ -1,8 +1,8 @@
 [Model]
 logic='FuzzyLogic'
-
-
-
+comparison_kwargs={'weight': 10}
+rounding_kwargs={'weight': 10}
+control_kwargs={'weight': 10}
 
 [Optimizer]
 method='JaxStraightLinePlan'
@@ -14,5 +14,5 @@ batch_size_test=32
 
 [Training]
 key=42
-epochs=
-train_seconds=
+epochs=2000
+train_seconds=30
```
```diff
@@ -1,14 +1,14 @@
 [Model]
 logic='FuzzyLogic'
-
-
-
+comparison_kwargs={'weight': 1}
+rounding_kwargs={'weight': 1}
+control_kwargs={'weight': 1}
 
 [Optimizer]
 method='JaxStraightLinePlan'
 method_kwargs={}
 optimizer='rmsprop'
-optimizer_kwargs={'learning_rate': 0.
+optimizer_kwargs={'learning_rate': 0.0003}
 batch_size_train=1
 batch_size_test=1
 
```
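The `weight` values threaded through `comparison_kwargs`, `rounding_kwargs`, and `control_kwargs` in the hunks above control how sharply `FuzzyLogic` relaxes discrete operations into differentiable ones. A minimal sketch of the idea, assuming a sigmoid-style relaxation of a comparison a > b (the hypothetical `soft_greater` below is not the library's code, and its actual operators may differ):

```python
import jax

def soft_greater(a, b, weight):
    # Illustrative only: an assumed form of the relaxation, not the library's code.
    # A differentiable stand-in for (a > b) that approaches the hard 0/1
    # comparison as weight grows.
    return jax.nn.sigmoid(weight * (a - b))

print(soft_greater(1.0, 0.0, 20.0))  # ~1.0: nearly a hard comparison
print(soft_greater(1.0, 0.0, 1.0))   # ~0.73: heavily smoothed
```

This would explain why the configs above pick a different weight per domain: a larger weight tracks the discrete model more faithfully but yields flatter, harder-to-optimize gradients.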
```diff
@@ -0,0 +1,19 @@
+[Model]
+logic='FuzzyLogic'
+comparison_kwargs={'weight': MODEL_WEIGHT_TUNE}
+rounding_kwargs={'weight': MODEL_WEIGHT_TUNE}
+control_kwargs={'weight': MODEL_WEIGHT_TUNE}
+
+[Optimizer]
+method='JaxDeepReactivePolicy'
+method_kwargs={'topology': [LAYER1_TUNE, LAYER2_TUNE]}
+optimizer='rmsprop'
+optimizer_kwargs={'learning_rate': LEARNING_RATE_TUNE}
+batch_size_train=32
+batch_size_test=32
+
+[Training]
+train_seconds=30
+policy_hyperparams=POLICY_WEIGHT_TUNE
+print_summary=False
+print_progress=False
```
```diff
@@ -0,0 +1,20 @@
+[Model]
+logic='FuzzyLogic'
+comparison_kwargs={'weight': MODEL_WEIGHT_TUNE}
+rounding_kwargs={'weight': MODEL_WEIGHT_TUNE}
+control_kwargs={'weight': MODEL_WEIGHT_TUNE}
+
+[Optimizer]
+method='JaxStraightLinePlan'
+method_kwargs={}
+optimizer='rmsprop'
+optimizer_kwargs={'learning_rate': LEARNING_RATE_TUNE}
+batch_size_train=32
+batch_size_test=32
+rollout_horizon=ROLLOUT_HORIZON_TUNE
+
+[Training]
+train_seconds=1
+policy_hyperparams=POLICY_WEIGHT_TUNE
+print_summary=False
+print_progress=False
```
```diff
@@ -0,0 +1,19 @@
+[Model]
+logic='FuzzyLogic'
+comparison_kwargs={'weight': MODEL_WEIGHT_TUNE}
+rounding_kwargs={'weight': MODEL_WEIGHT_TUNE}
+control_kwargs={'weight': MODEL_WEIGHT_TUNE}
+
+[Optimizer]
+method='JaxStraightLinePlan'
+method_kwargs={}
+optimizer='rmsprop'
+optimizer_kwargs={'learning_rate': LEARNING_RATE_TUNE}
+batch_size_train=32
+batch_size_test=32
+
+[Training]
+train_seconds=30
+policy_hyperparams=POLICY_WEIGHT_TUNE
+print_summary=False
+print_progress=False
```
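The bare `MODEL_WEIGHT_TUNE`, `LEARNING_RATE_TUNE`, and similar tokens make these new files templates rather than valid configs: the tuner substitutes a concrete value for each placeholder before the config is parsed. A minimal sketch of that substitution step, assuming plain string replacement (`fill_template` is a hypothetical helper; the package's own mechanism lives in `pyRDDLGym_jax.core.tuning`):

```python
def fill_template(config_template: str, assignment: dict) -> str:
    # Hypothetical illustration: replace each placeholder token in the
    # template with the value chosen for the current tuning trial.
    config = config_template
    for name, value in assignment.items():
        config = config.replace(name, repr(value))
    return config

filled = fill_template(
    "optimizer_kwargs={'learning_rate': LEARNING_RATE_TUNE}",
    {'LEARNING_RATE_TUNE': 0.001},
)
print(filled)  # optimizer_kwargs={'learning_rate': 0.001}
```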
```diff
@@ -39,9 +39,12 @@ def main(domain, instance, method, episodes=1):
               f'using default_{method}.cfg.', 'red')
         config_path = os.path.join(abs_path, 'configs', f'default_{method}.cfg')
     planner_args, _, train_args = load_config(config_path)
+    if 'dashboard' in train_args:
+        train_args['dashboard'].launch()
 
     # create the planning algorithm
-    planner = JaxBackpropPlanner(
+    planner = JaxBackpropPlanner(
+        rddl=env.model, dashboard_viz=env._visualizer, **planner_args)
 
     # evaluate the controller
     if method == 'replan':
```
```diff
@@ -20,12 +20,19 @@ import os
 import sys
 
 import pyRDDLGym
-from pyRDDLGym.core.debug.exception import raise_warning
 
-from pyRDDLGym_jax.core.tuning import
-
+from pyRDDLGym_jax.core.tuning import JaxParameterTuning, Hyperparameter
+from pyRDDLGym_jax.core.planner import (
+    load_config_from_string, JaxBackpropPlanner,
+    JaxOfflineController, JaxOnlineController
 )
-
+
+
+def power_2(x):
+    return int(2 ** x)
+
+
+def power_10(x):
+    return 10.0 ** x
 
 
 def main(domain, instance, method, trials=5, iters=20, workers=4):
```
```diff
@@ -35,31 +42,37 @@ def main(domain, instance, method, trials=5, iters=20, workers=4):
 
     # load the config file with planner settings
     abs_path = os.path.dirname(os.path.abspath(__file__))
-    config_path = os.path.join(abs_path, 'configs', f'{
-
-
-
-
-
+    config_path = os.path.join(abs_path, 'configs', f'tuning_{method}.cfg')
+    with open(config_path, 'r') as file:
+        config_template = file.read()
+
+    # map parameters in the config that will be tuned
+    hyperparams = [
+        Hyperparameter('MODEL_WEIGHT_TUNE', -1., 5., power_10),
+        Hyperparameter('POLICY_WEIGHT_TUNE', -2., 2., power_10),
+        Hyperparameter('LEARNING_RATE_TUNE', -5., 1., power_10),
+        Hyperparameter('LAYER1_TUNE', 1, 8, power_2),
+        Hyperparameter('LAYER2_TUNE', 1, 8, power_2),
+        Hyperparameter('ROLLOUT_HORIZON_TUNE', 1, min(env.horizon, 100), int)
+    ]
 
-    #
-
-
-
-
-
-
-
-
-        timeout_training=train_args['train_seconds'],
-        eval_trials=trials,
-        planner_kwargs=planner_args,
-        plan_kwargs=plan_args,
-        num_workers=workers,
-        gp_iters=iters)
+    # build the tuner and tune
+    tuning = JaxParameterTuning(env=env,
+                                config_template=config_template,
+                                hyperparams=hyperparams,
+                                online=method == 'replan',
+                                eval_trials=trials,
+                                num_workers=workers,
+                                gp_iters=iters)
+    tuning.tune(key=42, log_file=f'gp_{method}_{domain}_{instance}.csv')
 
-    #
-
+    # evaluate the agent on the best parameters
+    planner_args, _, train_args = load_config_from_string(tuning.best_config)
+    planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
+    klass = JaxOnlineController if method == 'replan' else JaxOfflineController
+    controller = klass(planner, **train_args)
+    controller.evaluate(env, episodes=1, verbose=True, render=True)
+    env.close()
 
 
 if __name__ == "__main__":
```
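Note how each `Hyperparameter` in the new run_tune.py pairs a raw search interval with a transform, so the Gaussian-process tuner can explore learning rates and model weights on a log scale and layer widths in powers of two. A quick, illustrative check of the value ranges the definitions above imply:

```python
# Transforms copied from the run_tune.py diff above.
def power_2(x):
    return int(2 ** x)

def power_10(x):
    return 10.0 ** x

# Effective ranges implied by the Hyperparameter bounds:
print(power_10(-5.0), power_10(1.0))  # LEARNING_RATE_TUNE: 1e-05 .. 10.0
print(power_2(1), power_2(8))         # LAYER1_TUNE / LAYER2_TUNE: 2 .. 256
print(power_10(-1.0), power_10(5.0))  # MODEL_WEIGHT_TUNE: 0.1 .. 100000.0
```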