pyRDDLGym-jax 2.8__py3-none-any.whl → 3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +1080 -906
- pyRDDLGym_jax/core/logic.py +1537 -1369
- pyRDDLGym_jax/core/model.py +75 -86
- pyRDDLGym_jax/core/planner.py +883 -935
- pyRDDLGym_jax/core/simulator.py +20 -17
- pyRDDLGym_jax/core/tuning.py +11 -7
- pyRDDLGym_jax/core/visualization.py +115 -78
- pyRDDLGym_jax/entry_point.py +2 -1
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +6 -8
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +7 -8
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +7 -8
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +8 -9
- pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +7 -8
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +6 -7
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +6 -7
- pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +6 -8
- pyRDDLGym_jax/examples/configs/Quadcopter_physics_drp.cfg +17 -0
- pyRDDLGym_jax/examples/configs/Quadcopter_physics_slp.cfg +17 -0
- pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +4 -7
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +4 -7
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +6 -7
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +6 -7
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +6 -7
- pyRDDLGym_jax/examples/configs/default_drp.cfg +5 -8
- pyRDDLGym_jax/examples/configs/default_replan.cfg +5 -8
- pyRDDLGym_jax/examples/configs/default_slp.cfg +5 -8
- pyRDDLGym_jax/examples/configs/tuning_drp.cfg +6 -8
- pyRDDLGym_jax/examples/configs/tuning_replan.cfg +6 -8
- pyRDDLGym_jax/examples/configs/tuning_slp.cfg +6 -8
- pyRDDLGym_jax/examples/run_plan.py +2 -2
- pyRDDLGym_jax/examples/run_tune.py +2 -2
- {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/METADATA +22 -23
- pyrddlgym_jax-3.0.dist-info/RECORD +51 -0
- {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax/examples/run_gradient.py +0 -102
- pyrddlgym_jax-2.8.dist-info/RECORD +0 -50
- {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/entry_points.txt +0 -0
- {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/licenses/LICENSE +0 -0
- {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/top_level.txt +0 -0
|
@@ -1,18 +1,17 @@
|
|
|
1
|
-
[
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
control_kwargs={'weight': 5}
|
|
1
|
+
[Compiler]
|
|
2
|
+
method='DefaultJaxRDDLCompilerWithGrad'
|
|
3
|
+
bernoulli_sigmoid_weight=5
|
|
4
|
+
sigmoid_weight=5
|
|
6
5
|
|
|
7
|
-
[
|
|
6
|
+
[Planner]
|
|
8
7
|
method='JaxStraightLinePlan'
|
|
9
8
|
method_kwargs={}
|
|
10
9
|
optimizer='rmsprop'
|
|
11
|
-
optimizer_kwargs={'learning_rate': 0.
|
|
10
|
+
optimizer_kwargs={'learning_rate': 0.1}
|
|
12
11
|
batch_size_train=1
|
|
13
12
|
batch_size_test=1
|
|
14
13
|
|
|
15
|
-
[
|
|
14
|
+
[Optimize]
|
|
16
15
|
key=42
|
|
17
16
|
epochs=30000
|
|
18
|
-
train_seconds=
|
|
17
|
+
train_seconds=60
|
|
@@ -1,10 +1,8 @@
|
|
|
1
|
-
[
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
rounding_kwargs={'weight': 10}
|
|
5
|
-
control_kwargs={'weight': 10}
|
|
1
|
+
[Compiler]
|
|
2
|
+
method='DefaultJaxRDDLCompilerWithGrad'
|
|
3
|
+
sigmoid_weight=10
|
|
6
4
|
|
|
7
|
-
[
|
|
5
|
+
[Planner]
|
|
8
6
|
method='JaxStraightLinePlan'
|
|
9
7
|
method_kwargs={}
|
|
10
8
|
optimizer='rmsprop'
|
|
@@ -13,7 +11,7 @@ batch_size_train=1
|
|
|
13
11
|
batch_size_test=1
|
|
14
12
|
clip_grad=1.0
|
|
15
13
|
|
|
16
|
-
[
|
|
14
|
+
[Optimize]
|
|
17
15
|
key=42
|
|
18
16
|
epochs=1000
|
|
19
17
|
train_seconds=30
|
|
@@ -1,10 +1,8 @@
|
|
|
1
|
-
[
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
rounding_kwargs={'weight': 10}
|
|
5
|
-
control_kwargs={'weight': 10}
|
|
1
|
+
[Compiler]
|
|
2
|
+
method='DefaultJaxRDDLCompilerWithGrad'
|
|
3
|
+
sigmoid_weight=10
|
|
6
4
|
|
|
7
|
-
[
|
|
5
|
+
[Planner]
|
|
8
6
|
method='JaxStraightLinePlan'
|
|
9
7
|
method_kwargs={}
|
|
10
8
|
optimizer='rmsprop'
|
|
@@ -13,7 +11,7 @@ batch_size_train=1
|
|
|
13
11
|
batch_size_test=1
|
|
14
12
|
clip_grad=1.0
|
|
15
13
|
|
|
16
|
-
[
|
|
14
|
+
[Optimize]
|
|
17
15
|
key=42
|
|
18
16
|
epochs=1000
|
|
19
17
|
train_seconds=30
|
|
@@ -1,18 +1,17 @@
|
|
|
1
|
-
[
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
rounding_kwargs={'weight': 10}
|
|
5
|
-
control_kwargs={'weight': 10}
|
|
1
|
+
[Compiler]
|
|
2
|
+
method='DefaultJaxRDDLCompilerWithGrad'
|
|
3
|
+
sigmoid_weight=10
|
|
6
4
|
|
|
7
|
-
[
|
|
5
|
+
[Planner]
|
|
8
6
|
method='JaxDeepReactivePolicy'
|
|
9
|
-
method_kwargs={'topology': [
|
|
7
|
+
method_kwargs={'topology': [128, 128]}
|
|
10
8
|
optimizer='rmsprop'
|
|
11
9
|
optimizer_kwargs={'learning_rate': 0.0001}
|
|
12
10
|
batch_size_train=32
|
|
13
11
|
batch_size_test=32
|
|
12
|
+
pgpe=None
|
|
14
13
|
|
|
15
|
-
[
|
|
14
|
+
[Optimize]
|
|
16
15
|
key=42
|
|
17
16
|
epochs=4000
|
|
18
17
|
train_seconds=30
|
|
@@ -1,10 +1,8 @@
|
|
|
1
|
-
[
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
rounding_kwargs={'weight': 10}
|
|
5
|
-
control_kwargs={'weight': 10}
|
|
1
|
+
[Compiler]
|
|
2
|
+
method='DefaultJaxRDDLCompilerWithGrad'
|
|
3
|
+
sigmoid_weight=10
|
|
6
4
|
|
|
7
|
-
[
|
|
5
|
+
[Planner]
|
|
8
6
|
method='JaxStraightLinePlan'
|
|
9
7
|
method_kwargs={}
|
|
10
8
|
optimizer='rmsprop'
|
|
@@ -12,8 +10,9 @@ optimizer_kwargs={'learning_rate': 0.1}
|
|
|
12
10
|
batch_size_train=32
|
|
13
11
|
batch_size_test=32
|
|
14
12
|
rollout_horizon=5
|
|
13
|
+
pgpe=None
|
|
15
14
|
|
|
16
|
-
[
|
|
15
|
+
[Optimize]
|
|
17
16
|
key=42
|
|
18
17
|
epochs=2000
|
|
19
18
|
train_seconds=1
|
|
@@ -1,18 +1,17 @@
|
|
|
1
|
-
[
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
rounding_kwargs={'weight': 10}
|
|
5
|
-
control_kwargs={'weight': 10}
|
|
1
|
+
[Compiler]
|
|
2
|
+
method='DefaultJaxRDDLCompilerWithGrad'
|
|
3
|
+
sigmoid_weight=10
|
|
6
4
|
|
|
7
|
-
[
|
|
5
|
+
[Planner]
|
|
8
6
|
method='JaxStraightLinePlan'
|
|
9
7
|
method_kwargs={}
|
|
10
8
|
optimizer='rmsprop'
|
|
11
9
|
optimizer_kwargs={'learning_rate': 0.05}
|
|
12
10
|
batch_size_train=32
|
|
13
11
|
batch_size_test=32
|
|
12
|
+
pgpe=None
|
|
14
13
|
|
|
15
|
-
[
|
|
14
|
+
[Optimize]
|
|
16
15
|
key=42
|
|
17
16
|
epochs=10000
|
|
18
17
|
train_seconds=30
|
|
@@ -1,19 +1,17 @@
|
|
|
1
|
-
[
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
rounding_kwargs={'weight': 10}
|
|
5
|
-
control_kwargs={'weight': 10}
|
|
1
|
+
[Compiler]
|
|
2
|
+
method='DefaultJaxRDDLCompilerWithGrad'
|
|
3
|
+
sigmoid_weight=10
|
|
6
4
|
|
|
7
|
-
[
|
|
5
|
+
[Planner]
|
|
8
6
|
method='JaxDeepReactivePolicy'
|
|
9
|
-
method_kwargs={'topology': [
|
|
7
|
+
method_kwargs={'topology': [128, 128], 'activation': 'tanh'}
|
|
10
8
|
optimizer='rmsprop'
|
|
11
9
|
optimizer_kwargs={'learning_rate': 0.001}
|
|
12
10
|
batch_size_train=1
|
|
13
11
|
batch_size_test=1
|
|
14
12
|
pgpe=None
|
|
15
13
|
|
|
16
|
-
[
|
|
14
|
+
[Optimize]
|
|
17
15
|
key=42
|
|
18
16
|
epochs=100000
|
|
19
17
|
train_seconds=360
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
[Compiler]
|
|
2
|
+
method='DefaultJaxRDDLCompilerWithGrad'
|
|
3
|
+
sigmoid_weight=10
|
|
4
|
+
|
|
5
|
+
[Planner]
|
|
6
|
+
method='JaxDeepReactivePolicy'
|
|
7
|
+
method_kwargs={'topology': [128, 128], 'activation': 'tanh'}
|
|
8
|
+
optimizer='rmsprop'
|
|
9
|
+
optimizer_kwargs={'learning_rate': 0.001}
|
|
10
|
+
batch_size_train=1
|
|
11
|
+
batch_size_test=1
|
|
12
|
+
pgpe=None
|
|
13
|
+
|
|
14
|
+
[Optimize]
|
|
15
|
+
key=42
|
|
16
|
+
epochs=100000
|
|
17
|
+
train_seconds=360
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
[Compiler]
|
|
2
|
+
method='DefaultJaxRDDLCompilerWithGrad'
|
|
3
|
+
sigmoid_weight=10
|
|
4
|
+
|
|
5
|
+
[Planner]
|
|
6
|
+
method='JaxStraightLinePlan'
|
|
7
|
+
method_kwargs={}
|
|
8
|
+
optimizer='rmsprop'
|
|
9
|
+
optimizer_kwargs={'learning_rate': 0.03}
|
|
10
|
+
batch_size_train=1
|
|
11
|
+
batch_size_test=1
|
|
12
|
+
pgpe=None
|
|
13
|
+
|
|
14
|
+
[Optimize]
|
|
15
|
+
key=42
|
|
16
|
+
epochs=100000
|
|
17
|
+
train_seconds=360
|
|
@@ -1,10 +1,8 @@
|
|
|
1
|
-
[
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
rounding_kwargs={'weight': 10}
|
|
5
|
-
control_kwargs={'weight': 10}
|
|
1
|
+
[Compiler]
|
|
2
|
+
method='DefaultJaxRDDLCompilerWithGrad'
|
|
3
|
+
sigmoid_weight=10
|
|
6
4
|
|
|
7
|
-
[
|
|
5
|
+
[Planner]
|
|
8
6
|
method='JaxStraightLinePlan'
|
|
9
7
|
method_kwargs={}
|
|
10
8
|
optimizer='rmsprop'
|
|
@@ -13,7 +11,7 @@ batch_size_train=1
|
|
|
13
11
|
batch_size_test=1
|
|
14
12
|
pgpe=None
|
|
15
13
|
|
|
16
|
-
[
|
|
14
|
+
[Optimize]
|
|
17
15
|
key=42
|
|
18
16
|
epochs=100000
|
|
19
17
|
train_seconds=360
|
|
@@ -1,10 +1,7 @@
|
|
|
1
|
-
[
|
|
2
|
-
|
|
3
|
-
comparison_kwargs={'weight': 10}
|
|
4
|
-
rounding_kwargs={'weight': 10}
|
|
5
|
-
control_kwargs={'weight': 10}
|
|
1
|
+
[Compiler]
|
|
2
|
+
method='DefaultJaxRDDLCompilerWithGrad'
|
|
6
3
|
|
|
7
|
-
[
|
|
4
|
+
[Planner]
|
|
8
5
|
method='JaxDeepReactivePolicy'
|
|
9
6
|
method_kwargs={'topology': [64, 32]}
|
|
10
7
|
optimizer='rmsprop'
|
|
@@ -13,7 +10,7 @@ batch_size_train=32
|
|
|
13
10
|
batch_size_test=32
|
|
14
11
|
pgpe=None
|
|
15
12
|
|
|
16
|
-
[
|
|
13
|
+
[Optimize]
|
|
17
14
|
key=42
|
|
18
15
|
epochs=5000
|
|
19
16
|
train_seconds=60
|
|
@@ -1,10 +1,7 @@
|
|
|
1
|
-
[
|
|
2
|
-
|
|
3
|
-
comparison_kwargs={'weight': 10}
|
|
4
|
-
rounding_kwargs={'weight': 10}
|
|
5
|
-
control_kwargs={'weight': 10}
|
|
1
|
+
[Compiler]
|
|
2
|
+
method='DefaultJaxRDDLCompilerWithGrad'
|
|
6
3
|
|
|
7
|
-
[
|
|
4
|
+
[Planner]
|
|
8
5
|
method='JaxStraightLinePlan'
|
|
9
6
|
method_kwargs={}
|
|
10
7
|
optimizer='rmsprop'
|
|
@@ -12,8 +9,9 @@ optimizer_kwargs={'learning_rate': 0.1}
|
|
|
12
9
|
batch_size_train=32
|
|
13
10
|
batch_size_test=32
|
|
14
11
|
rollout_horizon=5
|
|
12
|
+
pgpe=None
|
|
15
13
|
|
|
16
|
-
[
|
|
14
|
+
[Optimize]
|
|
17
15
|
key=42
|
|
18
16
|
epochs=500
|
|
19
17
|
train_seconds=1
|
|
@@ -1,10 +1,7 @@
|
|
|
1
|
-
[
|
|
2
|
-
|
|
3
|
-
comparison_kwargs={'weight': 10}
|
|
4
|
-
rounding_kwargs={'weight': 10}
|
|
5
|
-
control_kwargs={'weight': 10}
|
|
1
|
+
[Compiler]
|
|
2
|
+
method='DefaultJaxRDDLCompilerWithGrad'
|
|
6
3
|
|
|
7
|
-
[
|
|
4
|
+
[Planner]
|
|
8
5
|
method='JaxStraightLinePlan'
|
|
9
6
|
method_kwargs={}
|
|
10
7
|
optimizer='rmsprop'
|
|
@@ -13,7 +10,7 @@ batch_size_train=32
|
|
|
13
10
|
batch_size_test=32
|
|
14
11
|
pgpe=None
|
|
15
12
|
|
|
16
|
-
[
|
|
13
|
+
[Optimize]
|
|
17
14
|
key=42
|
|
18
15
|
epochs=2000
|
|
19
16
|
train_seconds=30
|
|
@@ -1,10 +1,8 @@
|
|
|
1
|
-
[
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
rounding_kwargs={'weight': 1}
|
|
5
|
-
control_kwargs={'weight': 1}
|
|
1
|
+
[Compiler]
|
|
2
|
+
method='DefaultJaxRDDLCompilerWithGrad'
|
|
3
|
+
sigmoid_weight=1
|
|
6
4
|
|
|
7
|
-
[
|
|
5
|
+
[Planner]
|
|
8
6
|
method='JaxStraightLinePlan'
|
|
9
7
|
method_kwargs={}
|
|
10
8
|
optimizer='rmsprop'
|
|
@@ -13,7 +11,7 @@ batch_size_train=1
|
|
|
13
11
|
batch_size_test=1
|
|
14
12
|
pgpe=None
|
|
15
13
|
|
|
16
|
-
[
|
|
14
|
+
[Optimize]
|
|
17
15
|
key=42
|
|
18
16
|
epochs=30000
|
|
19
17
|
train_seconds=120
|
|
@@ -1,10 +1,9 @@
|
|
|
1
|
-
[
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
control_kwargs={'weight': 100}
|
|
1
|
+
[Compiler]
|
|
2
|
+
method='DefaultJaxRDDLCompilerWithGrad'
|
|
3
|
+
sigmoid_weight=100
|
|
4
|
+
bernoulli_sigmoid_weight=100
|
|
6
5
|
|
|
7
|
-
[
|
|
6
|
+
[Planner]
|
|
8
7
|
method='JaxDeepReactivePolicy'
|
|
9
8
|
method_kwargs={'topology': [128, 64]}
|
|
10
9
|
optimizer='rmsprop'
|
|
@@ -13,7 +12,7 @@ batch_size_train=32
|
|
|
13
12
|
batch_size_test=32
|
|
14
13
|
pgpe=None
|
|
15
14
|
|
|
16
|
-
[
|
|
15
|
+
[Optimize]
|
|
17
16
|
key=42
|
|
18
17
|
epochs=1000
|
|
19
18
|
train_seconds=30
|
|
@@ -1,10 +1,9 @@
|
|
|
1
|
-
[
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
control_kwargs={'weight': 100}
|
|
1
|
+
[Compiler]
|
|
2
|
+
method='DefaultJaxRDDLCompilerWithGrad'
|
|
3
|
+
sigmoid_weight=100
|
|
4
|
+
bernoulli_sigmoid_weight=100
|
|
6
5
|
|
|
7
|
-
[
|
|
6
|
+
[Planner]
|
|
8
7
|
method='JaxStraightLinePlan'
|
|
9
8
|
method_kwargs={}
|
|
10
9
|
optimizer='rmsprop'
|
|
@@ -14,7 +13,7 @@ batch_size_test=32
|
|
|
14
13
|
rollout_horizon=5
|
|
15
14
|
pgpe=None
|
|
16
15
|
|
|
17
|
-
[
|
|
16
|
+
[Optimize]
|
|
18
17
|
key=42
|
|
19
18
|
epochs=1000
|
|
20
19
|
train_seconds=1
|
|
@@ -1,10 +1,9 @@
|
|
|
1
|
-
[
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
control_kwargs={'weight': 100}
|
|
1
|
+
[Compiler]
|
|
2
|
+
method='DefaultJaxRDDLCompilerWithGrad'
|
|
3
|
+
sigmoid_weight=100
|
|
4
|
+
bernoulli_sigmoid_weight=100
|
|
6
5
|
|
|
7
|
-
[
|
|
6
|
+
[Planner]
|
|
8
7
|
method='JaxStraightLinePlan'
|
|
9
8
|
method_kwargs={}
|
|
10
9
|
optimizer='rmsprop'
|
|
@@ -13,7 +12,7 @@ batch_size_train=32
|
|
|
13
12
|
batch_size_test=32
|
|
14
13
|
pgpe=None
|
|
15
14
|
|
|
16
|
-
[
|
|
15
|
+
[Optimize]
|
|
17
16
|
key=42
|
|
18
17
|
epochs=1000
|
|
19
18
|
train_seconds=30
|
|
@@ -1,10 +1,7 @@
|
|
|
1
|
-
[
|
|
2
|
-
|
|
3
|
-
comparison_kwargs={'weight': 20}
|
|
4
|
-
rounding_kwargs={'weight': 20}
|
|
5
|
-
control_kwargs={'weight': 20}
|
|
1
|
+
[Compiler]
|
|
2
|
+
method='DefaultJaxRDDLCompilerWithGrad'
|
|
6
3
|
|
|
7
|
-
[
|
|
4
|
+
[Planner]
|
|
8
5
|
method='JaxDeepReactivePolicy'
|
|
9
6
|
method_kwargs={}
|
|
10
7
|
optimizer='rmsprop'
|
|
@@ -12,8 +9,8 @@ optimizer_kwargs={'learning_rate': 0.0001}
|
|
|
12
9
|
batch_size_train=32
|
|
13
10
|
batch_size_test=32
|
|
14
11
|
|
|
15
|
-
[
|
|
12
|
+
[Optimize]
|
|
16
13
|
key=42
|
|
17
14
|
epochs=30000
|
|
18
15
|
train_seconds=60
|
|
19
|
-
policy_hyperparams=
|
|
16
|
+
policy_hyperparams=1.0
|
|
@@ -1,10 +1,7 @@
|
|
|
1
|
-
[
|
|
2
|
-
|
|
3
|
-
comparison_kwargs={'weight': 20}
|
|
4
|
-
rounding_kwargs={'weight': 20}
|
|
5
|
-
control_kwargs={'weight': 20}
|
|
1
|
+
[Compiler]
|
|
2
|
+
method='DefaultJaxRDDLCompilerWithGrad'
|
|
6
3
|
|
|
7
|
-
[
|
|
4
|
+
[Planner]
|
|
8
5
|
method='JaxStraightLinePlan'
|
|
9
6
|
method_kwargs={}
|
|
10
7
|
optimizer='rmsprop'
|
|
@@ -13,9 +10,9 @@ batch_size_train=32
|
|
|
13
10
|
batch_size_test=32
|
|
14
11
|
rollout_horizon=5
|
|
15
12
|
|
|
16
|
-
[
|
|
13
|
+
[Optimize]
|
|
17
14
|
key=42
|
|
18
15
|
epochs=2000
|
|
19
16
|
train_seconds=1
|
|
20
|
-
policy_hyperparams=
|
|
17
|
+
policy_hyperparams=1.0
|
|
21
18
|
print_summary=False
|
|
@@ -1,10 +1,7 @@
|
|
|
1
|
-
[
|
|
2
|
-
|
|
3
|
-
comparison_kwargs={'weight': 20}
|
|
4
|
-
rounding_kwargs={'weight': 20}
|
|
5
|
-
control_kwargs={'weight': 20}
|
|
1
|
+
[Compiler]
|
|
2
|
+
method='DefaultJaxRDDLCompilerWithGrad'
|
|
6
3
|
|
|
7
|
-
[
|
|
4
|
+
[Planner]
|
|
8
5
|
method='JaxStraightLinePlan'
|
|
9
6
|
method_kwargs={}
|
|
10
7
|
optimizer='rmsprop'
|
|
@@ -12,8 +9,8 @@ optimizer_kwargs={'learning_rate': 0.01}
|
|
|
12
9
|
batch_size_train=32
|
|
13
10
|
batch_size_test=32
|
|
14
11
|
|
|
15
|
-
[
|
|
12
|
+
[Optimize]
|
|
16
13
|
key=42
|
|
17
14
|
epochs=30000
|
|
18
15
|
train_seconds=60
|
|
19
|
-
policy_hyperparams=
|
|
16
|
+
policy_hyperparams=1.0
|
|
@@ -1,19 +1,17 @@
|
|
|
1
|
-
[
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
control_kwargs={'weight': MODEL_WEIGHT_TUNE}
|
|
1
|
+
[Compiler]
|
|
2
|
+
method='DefaultJaxRDDLCompilerWithGrad'
|
|
3
|
+
sigmoid_weight=MODEL_WEIGHT_TUNE
|
|
4
|
+
print_warnings=False
|
|
6
5
|
|
|
7
|
-
[
|
|
6
|
+
[Planner]
|
|
8
7
|
method='JaxDeepReactivePolicy'
|
|
9
8
|
method_kwargs={'topology': [LAYER1_TUNE, LAYER2_TUNE]}
|
|
10
9
|
optimizer='rmsprop'
|
|
11
10
|
optimizer_kwargs={'learning_rate': LEARNING_RATE_TUNE}
|
|
12
11
|
batch_size_train=32
|
|
13
12
|
batch_size_test=32
|
|
14
|
-
print_warnings=False
|
|
15
13
|
|
|
16
|
-
[
|
|
14
|
+
[Optimize]
|
|
17
15
|
train_seconds=30
|
|
18
16
|
policy_hyperparams=POLICY_WEIGHT_TUNE
|
|
19
17
|
print_summary=False
|
|
@@ -1,10 +1,9 @@
|
|
|
1
|
-
[
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
control_kwargs={'weight': MODEL_WEIGHT_TUNE}
|
|
1
|
+
[Compiler]
|
|
2
|
+
method='DefaultJaxRDDLCompilerWithGrad'
|
|
3
|
+
sigmoid_weight=MODEL_WEIGHT_TUNE
|
|
4
|
+
print_warnings=False
|
|
6
5
|
|
|
7
|
-
[
|
|
6
|
+
[Planner]
|
|
8
7
|
method='JaxStraightLinePlan'
|
|
9
8
|
method_kwargs={}
|
|
10
9
|
optimizer='rmsprop'
|
|
@@ -12,9 +11,8 @@ optimizer_kwargs={'learning_rate': LEARNING_RATE_TUNE}
|
|
|
12
11
|
batch_size_train=32
|
|
13
12
|
batch_size_test=32
|
|
14
13
|
rollout_horizon=ROLLOUT_HORIZON_TUNE
|
|
15
|
-
print_warnings=False
|
|
16
14
|
|
|
17
|
-
[
|
|
15
|
+
[Optimize]
|
|
18
16
|
train_seconds=1
|
|
19
17
|
policy_hyperparams=POLICY_WEIGHT_TUNE
|
|
20
18
|
print_summary=False
|
|
@@ -1,19 +1,17 @@
|
|
|
1
|
-
[
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
control_kwargs={'weight': MODEL_WEIGHT_TUNE}
|
|
1
|
+
[Compiler]
|
|
2
|
+
method='DefaultJaxRDDLCompilerWithGrad'
|
|
3
|
+
sigmoid_weight=MODEL_WEIGHT_TUNE
|
|
4
|
+
print_warnings=False
|
|
6
5
|
|
|
7
|
-
[
|
|
6
|
+
[Planner]
|
|
8
7
|
method='JaxStraightLinePlan'
|
|
9
8
|
method_kwargs={}
|
|
10
9
|
optimizer='rmsprop'
|
|
11
10
|
optimizer_kwargs={'learning_rate': LEARNING_RATE_TUNE}
|
|
12
11
|
batch_size_train=32
|
|
13
12
|
batch_size_test=32
|
|
14
|
-
print_warnings=False
|
|
15
13
|
|
|
16
|
-
[
|
|
14
|
+
[Optimize]
|
|
17
15
|
train_seconds=30
|
|
18
16
|
policy_hyperparams=POLICY_WEIGHT_TUNE
|
|
19
17
|
print_summary=False
|
|
@@ -46,8 +46,8 @@ def main(domain: str, instance: str, method: str, episodes: int=1) -> None:
|
|
|
46
46
|
exit(1)
|
|
47
47
|
|
|
48
48
|
planner_args, _, train_args = load_config(config_path)
|
|
49
|
-
if 'dashboard' in
|
|
50
|
-
|
|
49
|
+
if 'dashboard' in planner_args:
|
|
50
|
+
planner_args['dashboard'].launch()
|
|
51
51
|
|
|
52
52
|
# create the planning algorithm
|
|
53
53
|
planner = JaxBackpropPlanner(
|
|
@@ -77,8 +77,8 @@ def main(domain: str, instance: str, method: str,
|
|
|
77
77
|
# evaluate the agent on the best parameters
|
|
78
78
|
planner_args, _, train_args = load_config_from_string(tuning.best_config)
|
|
79
79
|
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
|
|
80
|
-
|
|
81
|
-
controller =
|
|
80
|
+
class_ = JaxOnlineController if method == 'replan' else JaxOfflineController
|
|
81
|
+
controller = class_(planner, **train_args)
|
|
82
82
|
controller.evaluate(env, episodes=1, verbose=True, render=True)
|
|
83
83
|
env.close()
|
|
84
84
|
|