pyRDDLGym-jax 2.7__py3-none-any.whl → 3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. pyRDDLGym_jax/__init__.py +1 -1
  2. pyRDDLGym_jax/core/compiler.py +1080 -906
  3. pyRDDLGym_jax/core/logic.py +1537 -1369
  4. pyRDDLGym_jax/core/model.py +75 -86
  5. pyRDDLGym_jax/core/planner.py +883 -935
  6. pyRDDLGym_jax/core/simulator.py +20 -17
  7. pyRDDLGym_jax/core/tuning.py +11 -7
  8. pyRDDLGym_jax/core/visualization.py +115 -78
  9. pyRDDLGym_jax/entry_point.py +2 -1
  10. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +6 -8
  11. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +5 -7
  12. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +7 -8
  13. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +7 -8
  14. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +8 -9
  15. pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +5 -7
  16. pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +5 -7
  17. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +7 -8
  18. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +6 -7
  19. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +6 -7
  20. pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +6 -8
  21. pyRDDLGym_jax/examples/configs/Quadcopter_physics_drp.cfg +17 -0
  22. pyRDDLGym_jax/examples/configs/Quadcopter_physics_slp.cfg +17 -0
  23. pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +5 -7
  24. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +4 -7
  25. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +5 -7
  26. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +4 -7
  27. pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +5 -7
  28. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +6 -7
  29. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +6 -7
  30. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +6 -7
  31. pyRDDLGym_jax/examples/configs/default_drp.cfg +5 -8
  32. pyRDDLGym_jax/examples/configs/default_replan.cfg +5 -8
  33. pyRDDLGym_jax/examples/configs/default_slp.cfg +5 -8
  34. pyRDDLGym_jax/examples/configs/tuning_drp.cfg +6 -8
  35. pyRDDLGym_jax/examples/configs/tuning_replan.cfg +6 -8
  36. pyRDDLGym_jax/examples/configs/tuning_slp.cfg +6 -8
  37. pyRDDLGym_jax/examples/run_plan.py +2 -33
  38. pyRDDLGym_jax/examples/run_tune.py +2 -2
  39. {pyrddlgym_jax-2.7.dist-info → pyrddlgym_jax-3.0.dist-info}/METADATA +22 -23
  40. pyrddlgym_jax-3.0.dist-info/RECORD +51 -0
  41. {pyrddlgym_jax-2.7.dist-info → pyrddlgym_jax-3.0.dist-info}/WHEEL +1 -1
  42. pyRDDLGym_jax/examples/run_gradient.py +0 -102
  43. pyrddlgym_jax-2.7.dist-info/RECORD +0 -50
  44. {pyrddlgym_jax-2.7.dist-info → pyrddlgym_jax-3.0.dist-info}/entry_points.txt +0 -0
  45. {pyrddlgym_jax-2.7.dist-info → pyrddlgym_jax-3.0.dist-info}/licenses/LICENSE +0 -0
  46. {pyrddlgym_jax-2.7.dist-info → pyrddlgym_jax-3.0.dist-info}/top_level.txt +0 -0
@@ -1,18 +1,17 @@
1
- [Model]
2
- logic='FuzzyLogic'
3
- comparison_kwargs={'weight': 5}
4
- rounding_kwargs={'weight': 5}
5
- control_kwargs={'weight': 5}
1
+ [Compiler]
2
+ method='DefaultJaxRDDLCompilerWithGrad'
3
+ bernoulli_sigmoid_weight=5
4
+ sigmoid_weight=5
6
5
 
7
- [Optimizer]
6
+ [Planner]
8
7
  method='JaxStraightLinePlan'
9
8
  method_kwargs={}
10
9
  optimizer='rmsprop'
11
- optimizer_kwargs={'learning_rate': 0.02}
10
+ optimizer_kwargs={'learning_rate': 0.1}
12
11
  batch_size_train=1
13
12
  batch_size_test=1
14
13
 
15
- [Training]
14
+ [Optimize]
16
15
  key=42
17
16
  epochs=30000
18
- train_seconds=30
17
+ train_seconds=60
@@ -1,10 +1,8 @@
1
- [Model]
2
- logic='FuzzyLogic'
3
- comparison_kwargs={'weight': 10}
4
- rounding_kwargs={'weight': 10}
5
- control_kwargs={'weight': 10}
1
+ [Compiler]
2
+ method='DefaultJaxRDDLCompilerWithGrad'
3
+ sigmoid_weight=10
6
4
 
7
- [Optimizer]
5
+ [Planner]
8
6
  method='JaxStraightLinePlan'
9
7
  method_kwargs={}
10
8
  optimizer='rmsprop'
@@ -13,7 +11,7 @@ batch_size_train=1
13
11
  batch_size_test=1
14
12
  clip_grad=1.0
15
13
 
16
- [Training]
14
+ [Optimize]
17
15
  key=42
18
16
  epochs=1000
19
17
  train_seconds=30
@@ -1,10 +1,8 @@
1
- [Model]
2
- logic='FuzzyLogic'
3
- comparison_kwargs={'weight': 10}
4
- rounding_kwargs={'weight': 10}
5
- control_kwargs={'weight': 10}
1
+ [Compiler]
2
+ method='DefaultJaxRDDLCompilerWithGrad'
3
+ sigmoid_weight=10
6
4
 
7
- [Optimizer]
5
+ [Planner]
8
6
  method='JaxStraightLinePlan'
9
7
  method_kwargs={}
10
8
  optimizer='rmsprop'
@@ -13,7 +11,7 @@ batch_size_train=1
13
11
  batch_size_test=1
14
12
  clip_grad=1.0
15
13
 
16
- [Training]
14
+ [Optimize]
17
15
  key=42
18
16
  epochs=1000
19
17
  train_seconds=30
@@ -1,18 +1,17 @@
1
- [Model]
2
- logic='FuzzyLogic'
3
- comparison_kwargs={'weight': 10}
4
- rounding_kwargs={'weight': 10}
5
- control_kwargs={'weight': 10}
1
+ [Compiler]
2
+ method='DefaultJaxRDDLCompilerWithGrad'
3
+ sigmoid_weight=10
6
4
 
7
- [Optimizer]
5
+ [Planner]
8
6
  method='JaxDeepReactivePolicy'
9
- method_kwargs={'topology': [256, 128]}
7
+ method_kwargs={'topology': [128, 128]}
10
8
  optimizer='rmsprop'
11
9
  optimizer_kwargs={'learning_rate': 0.0001}
12
10
  batch_size_train=32
13
11
  batch_size_test=32
12
+ pgpe=None
14
13
 
15
- [Training]
14
+ [Optimize]
16
15
  key=42
17
16
  epochs=4000
18
17
  train_seconds=30
@@ -1,10 +1,8 @@
1
- [Model]
2
- logic='FuzzyLogic'
3
- comparison_kwargs={'weight': 10}
4
- rounding_kwargs={'weight': 10}
5
- control_kwargs={'weight': 10}
1
+ [Compiler]
2
+ method='DefaultJaxRDDLCompilerWithGrad'
3
+ sigmoid_weight=10
6
4
 
7
- [Optimizer]
5
+ [Planner]
8
6
  method='JaxStraightLinePlan'
9
7
  method_kwargs={}
10
8
  optimizer='rmsprop'
@@ -12,8 +10,9 @@ optimizer_kwargs={'learning_rate': 0.1}
12
10
  batch_size_train=32
13
11
  batch_size_test=32
14
12
  rollout_horizon=5
13
+ pgpe=None
15
14
 
16
- [Training]
15
+ [Optimize]
17
16
  key=42
18
17
  epochs=2000
19
18
  train_seconds=1
@@ -1,18 +1,17 @@
1
- [Model]
2
- logic='FuzzyLogic'
3
- comparison_kwargs={'weight': 10}
4
- rounding_kwargs={'weight': 10}
5
- control_kwargs={'weight': 10}
1
+ [Compiler]
2
+ method='DefaultJaxRDDLCompilerWithGrad'
3
+ sigmoid_weight=10
6
4
 
7
- [Optimizer]
5
+ [Planner]
8
6
  method='JaxStraightLinePlan'
9
7
  method_kwargs={}
10
8
  optimizer='rmsprop'
11
9
  optimizer_kwargs={'learning_rate': 0.05}
12
10
  batch_size_train=32
13
11
  batch_size_test=32
12
+ pgpe=None
14
13
 
15
- [Training]
14
+ [Optimize]
16
15
  key=42
17
16
  epochs=10000
18
17
  train_seconds=30
@@ -1,19 +1,17 @@
1
- [Model]
2
- logic='FuzzyLogic'
3
- comparison_kwargs={'weight': 10}
4
- rounding_kwargs={'weight': 10}
5
- control_kwargs={'weight': 10}
1
+ [Compiler]
2
+ method='DefaultJaxRDDLCompilerWithGrad'
3
+ sigmoid_weight=10
6
4
 
7
- [Optimizer]
5
+ [Planner]
8
6
  method='JaxDeepReactivePolicy'
9
- method_kwargs={'topology': [256, 128], 'activation': 'tanh'}
7
+ method_kwargs={'topology': [128, 128], 'activation': 'tanh'}
10
8
  optimizer='rmsprop'
11
9
  optimizer_kwargs={'learning_rate': 0.001}
12
10
  batch_size_train=1
13
11
  batch_size_test=1
14
12
  pgpe=None
15
13
 
16
- [Training]
14
+ [Optimize]
17
15
  key=42
18
16
  epochs=100000
19
17
  train_seconds=360
@@ -0,0 +1,17 @@
1
+ [Compiler]
2
+ method='DefaultJaxRDDLCompilerWithGrad'
3
+ sigmoid_weight=10
4
+
5
+ [Planner]
6
+ method='JaxDeepReactivePolicy'
7
+ method_kwargs={'topology': [128, 128], 'activation': 'tanh'}
8
+ optimizer='rmsprop'
9
+ optimizer_kwargs={'learning_rate': 0.001}
10
+ batch_size_train=1
11
+ batch_size_test=1
12
+ pgpe=None
13
+
14
+ [Optimize]
15
+ key=42
16
+ epochs=100000
17
+ train_seconds=360
@@ -0,0 +1,17 @@
1
+ [Compiler]
2
+ method='DefaultJaxRDDLCompilerWithGrad'
3
+ sigmoid_weight=10
4
+
5
+ [Planner]
6
+ method='JaxStraightLinePlan'
7
+ method_kwargs={}
8
+ optimizer='rmsprop'
9
+ optimizer_kwargs={'learning_rate': 0.03}
10
+ batch_size_train=1
11
+ batch_size_test=1
12
+ pgpe=None
13
+
14
+ [Optimize]
15
+ key=42
16
+ epochs=100000
17
+ train_seconds=360
@@ -1,10 +1,8 @@
1
- [Model]
2
- logic='FuzzyLogic'
3
- comparison_kwargs={'weight': 10}
4
- rounding_kwargs={'weight': 10}
5
- control_kwargs={'weight': 10}
1
+ [Compiler]
2
+ method='DefaultJaxRDDLCompilerWithGrad'
3
+ sigmoid_weight=10
6
4
 
7
- [Optimizer]
5
+ [Planner]
8
6
  method='JaxStraightLinePlan'
9
7
  method_kwargs={}
10
8
  optimizer='rmsprop'
@@ -13,7 +11,7 @@ batch_size_train=1
13
11
  batch_size_test=1
14
12
  pgpe=None
15
13
 
16
- [Training]
14
+ [Optimize]
17
15
  key=42
18
16
  epochs=100000
19
17
  train_seconds=360
@@ -1,10 +1,7 @@
1
- [Model]
2
- logic='FuzzyLogic'
3
- comparison_kwargs={'weight': 10}
4
- rounding_kwargs={'weight': 10}
5
- control_kwargs={'weight': 10}
1
+ [Compiler]
2
+ method='DefaultJaxRDDLCompilerWithGrad'
6
3
 
7
- [Optimizer]
4
+ [Planner]
8
5
  method='JaxDeepReactivePolicy'
9
6
  method_kwargs={'topology': [64, 32]}
10
7
  optimizer='rmsprop'
@@ -13,7 +10,7 @@ batch_size_train=32
13
10
  batch_size_test=32
14
11
  pgpe=None
15
12
 
16
- [Training]
13
+ [Optimize]
17
14
  key=42
18
15
  epochs=5000
19
16
  train_seconds=60
@@ -1,10 +1,7 @@
1
- [Model]
2
- logic='FuzzyLogic'
3
- comparison_kwargs={'weight': 10}
4
- rounding_kwargs={'weight': 10}
5
- control_kwargs={'weight': 10}
1
+ [Compiler]
2
+ method='DefaultJaxRDDLCompilerWithGrad'
6
3
 
7
- [Optimizer]
4
+ [Planner]
8
5
  method='JaxStraightLinePlan'
9
6
  method_kwargs={}
10
7
  optimizer='rmsprop'
@@ -12,8 +9,9 @@ optimizer_kwargs={'learning_rate': 0.1}
12
9
  batch_size_train=32
13
10
  batch_size_test=32
14
11
  rollout_horizon=5
12
+ pgpe=None
15
13
 
16
- [Training]
14
+ [Optimize]
17
15
  key=42
18
16
  epochs=500
19
17
  train_seconds=1
@@ -1,10 +1,7 @@
1
- [Model]
2
- logic='FuzzyLogic'
3
- comparison_kwargs={'weight': 10}
4
- rounding_kwargs={'weight': 10}
5
- control_kwargs={'weight': 10}
1
+ [Compiler]
2
+ method='DefaultJaxRDDLCompilerWithGrad'
6
3
 
7
- [Optimizer]
4
+ [Planner]
8
5
  method='JaxStraightLinePlan'
9
6
  method_kwargs={}
10
7
  optimizer='rmsprop'
@@ -13,7 +10,7 @@ batch_size_train=32
13
10
  batch_size_test=32
14
11
  pgpe=None
15
12
 
16
- [Training]
13
+ [Optimize]
17
14
  key=42
18
15
  epochs=2000
19
16
  train_seconds=30
@@ -1,10 +1,8 @@
1
- [Model]
2
- logic='FuzzyLogic'
3
- comparison_kwargs={'weight': 1}
4
- rounding_kwargs={'weight': 1}
5
- control_kwargs={'weight': 1}
1
+ [Compiler]
2
+ method='DefaultJaxRDDLCompilerWithGrad'
3
+ sigmoid_weight=1
6
4
 
7
- [Optimizer]
5
+ [Planner]
8
6
  method='JaxStraightLinePlan'
9
7
  method_kwargs={}
10
8
  optimizer='rmsprop'
@@ -13,7 +11,7 @@ batch_size_train=1
13
11
  batch_size_test=1
14
12
  pgpe=None
15
13
 
16
- [Training]
14
+ [Optimize]
17
15
  key=42
18
16
  epochs=30000
19
17
  train_seconds=120
@@ -1,10 +1,9 @@
1
- [Model]
2
- logic='FuzzyLogic'
3
- comparison_kwargs={'weight': 100}
4
- rounding_kwargs={'weight': 100}
5
- control_kwargs={'weight': 100}
1
+ [Compiler]
2
+ method='DefaultJaxRDDLCompilerWithGrad'
3
+ sigmoid_weight=100
4
+ bernoulli_sigmoid_weight=100
6
5
 
7
- [Optimizer]
6
+ [Planner]
8
7
  method='JaxDeepReactivePolicy'
9
8
  method_kwargs={'topology': [128, 64]}
10
9
  optimizer='rmsprop'
@@ -13,7 +12,7 @@ batch_size_train=32
13
12
  batch_size_test=32
14
13
  pgpe=None
15
14
 
16
- [Training]
15
+ [Optimize]
17
16
  key=42
18
17
  epochs=1000
19
18
  train_seconds=30
@@ -1,10 +1,9 @@
1
- [Model]
2
- logic='FuzzyLogic'
3
- comparison_kwargs={'weight': 100}
4
- rounding_kwargs={'weight': 100}
5
- control_kwargs={'weight': 100}
1
+ [Compiler]
2
+ method='DefaultJaxRDDLCompilerWithGrad'
3
+ sigmoid_weight=100
4
+ bernoulli_sigmoid_weight=100
6
5
 
7
- [Optimizer]
6
+ [Planner]
8
7
  method='JaxStraightLinePlan'
9
8
  method_kwargs={}
10
9
  optimizer='rmsprop'
@@ -14,7 +13,7 @@ batch_size_test=32
14
13
  rollout_horizon=5
15
14
  pgpe=None
16
15
 
17
- [Training]
16
+ [Optimize]
18
17
  key=42
19
18
  epochs=1000
20
19
  train_seconds=1
@@ -1,10 +1,9 @@
1
- [Model]
2
- logic='FuzzyLogic'
3
- comparison_kwargs={'weight': 100}
4
- rounding_kwargs={'weight': 100}
5
- control_kwargs={'weight': 100}
1
+ [Compiler]
2
+ method='DefaultJaxRDDLCompilerWithGrad'
3
+ sigmoid_weight=100
4
+ bernoulli_sigmoid_weight=100
6
5
 
7
- [Optimizer]
6
+ [Planner]
8
7
  method='JaxStraightLinePlan'
9
8
  method_kwargs={}
10
9
  optimizer='rmsprop'
@@ -13,7 +12,7 @@ batch_size_train=32
13
12
  batch_size_test=32
14
13
  pgpe=None
15
14
 
16
- [Training]
15
+ [Optimize]
17
16
  key=42
18
17
  epochs=1000
19
18
  train_seconds=30
@@ -1,10 +1,7 @@
1
- [Model]
2
- logic='FuzzyLogic'
3
- comparison_kwargs={'weight': 20}
4
- rounding_kwargs={'weight': 20}
5
- control_kwargs={'weight': 20}
1
+ [Compiler]
2
+ method='DefaultJaxRDDLCompilerWithGrad'
6
3
 
7
- [Optimizer]
4
+ [Planner]
8
5
  method='JaxDeepReactivePolicy'
9
6
  method_kwargs={}
10
7
  optimizer='rmsprop'
@@ -12,8 +9,8 @@ optimizer_kwargs={'learning_rate': 0.0001}
12
9
  batch_size_train=32
13
10
  batch_size_test=32
14
11
 
15
- [Training]
12
+ [Optimize]
16
13
  key=42
17
14
  epochs=30000
18
15
  train_seconds=60
19
- policy_hyperparams=2.0
16
+ policy_hyperparams=1.0
@@ -1,10 +1,7 @@
1
- [Model]
2
- logic='FuzzyLogic'
3
- comparison_kwargs={'weight': 20}
4
- rounding_kwargs={'weight': 20}
5
- control_kwargs={'weight': 20}
1
+ [Compiler]
2
+ method='DefaultJaxRDDLCompilerWithGrad'
6
3
 
7
- [Optimizer]
4
+ [Planner]
8
5
  method='JaxStraightLinePlan'
9
6
  method_kwargs={}
10
7
  optimizer='rmsprop'
@@ -13,9 +10,9 @@ batch_size_train=32
13
10
  batch_size_test=32
14
11
  rollout_horizon=5
15
12
 
16
- [Training]
13
+ [Optimize]
17
14
  key=42
18
15
  epochs=2000
19
16
  train_seconds=1
20
- policy_hyperparams=2.0
17
+ policy_hyperparams=1.0
21
18
  print_summary=False
@@ -1,10 +1,7 @@
1
- [Model]
2
- logic='FuzzyLogic'
3
- comparison_kwargs={'weight': 20}
4
- rounding_kwargs={'weight': 20}
5
- control_kwargs={'weight': 20}
1
+ [Compiler]
2
+ method='DefaultJaxRDDLCompilerWithGrad'
6
3
 
7
- [Optimizer]
4
+ [Planner]
8
5
  method='JaxStraightLinePlan'
9
6
  method_kwargs={}
10
7
  optimizer='rmsprop'
@@ -12,8 +9,8 @@ optimizer_kwargs={'learning_rate': 0.01}
12
9
  batch_size_train=32
13
10
  batch_size_test=32
14
11
 
15
- [Training]
12
+ [Optimize]
16
13
  key=42
17
14
  epochs=30000
18
15
  train_seconds=60
19
- policy_hyperparams=2.0
16
+ policy_hyperparams=1.0
@@ -1,19 +1,17 @@
1
- [Model]
2
- logic='FuzzyLogic'
3
- comparison_kwargs={'weight': MODEL_WEIGHT_TUNE}
4
- rounding_kwargs={'weight': MODEL_WEIGHT_TUNE}
5
- control_kwargs={'weight': MODEL_WEIGHT_TUNE}
1
+ [Compiler]
2
+ method='DefaultJaxRDDLCompilerWithGrad'
3
+ sigmoid_weight=MODEL_WEIGHT_TUNE
4
+ print_warnings=False
6
5
 
7
- [Optimizer]
6
+ [Planner]
8
7
  method='JaxDeepReactivePolicy'
9
8
  method_kwargs={'topology': [LAYER1_TUNE, LAYER2_TUNE]}
10
9
  optimizer='rmsprop'
11
10
  optimizer_kwargs={'learning_rate': LEARNING_RATE_TUNE}
12
11
  batch_size_train=32
13
12
  batch_size_test=32
14
- print_warnings=False
15
13
 
16
- [Training]
14
+ [Optimize]
17
15
  train_seconds=30
18
16
  policy_hyperparams=POLICY_WEIGHT_TUNE
19
17
  print_summary=False
@@ -1,10 +1,9 @@
1
- [Model]
2
- logic='FuzzyLogic'
3
- comparison_kwargs={'weight': MODEL_WEIGHT_TUNE}
4
- rounding_kwargs={'weight': MODEL_WEIGHT_TUNE}
5
- control_kwargs={'weight': MODEL_WEIGHT_TUNE}
1
+ [Compiler]
2
+ method='DefaultJaxRDDLCompilerWithGrad'
3
+ sigmoid_weight=MODEL_WEIGHT_TUNE
4
+ print_warnings=False
6
5
 
7
- [Optimizer]
6
+ [Planner]
8
7
  method='JaxStraightLinePlan'
9
8
  method_kwargs={}
10
9
  optimizer='rmsprop'
@@ -12,9 +11,8 @@ optimizer_kwargs={'learning_rate': LEARNING_RATE_TUNE}
12
11
  batch_size_train=32
13
12
  batch_size_test=32
14
13
  rollout_horizon=ROLLOUT_HORIZON_TUNE
15
- print_warnings=False
16
14
 
17
- [Training]
15
+ [Optimize]
18
16
  train_seconds=1
19
17
  policy_hyperparams=POLICY_WEIGHT_TUNE
20
18
  print_summary=False
@@ -1,19 +1,17 @@
1
- [Model]
2
- logic='FuzzyLogic'
3
- comparison_kwargs={'weight': MODEL_WEIGHT_TUNE}
4
- rounding_kwargs={'weight': MODEL_WEIGHT_TUNE}
5
- control_kwargs={'weight': MODEL_WEIGHT_TUNE}
1
+ [Compiler]
2
+ method='DefaultJaxRDDLCompilerWithGrad'
3
+ sigmoid_weight=MODEL_WEIGHT_TUNE
4
+ print_warnings=False
6
5
 
7
- [Optimizer]
6
+ [Planner]
8
7
  method='JaxStraightLinePlan'
9
8
  method_kwargs={}
10
9
  optimizer='rmsprop'
11
10
  optimizer_kwargs={'learning_rate': LEARNING_RATE_TUNE}
12
11
  batch_size_train=32
13
12
  batch_size_test=32
14
- print_warnings=False
15
13
 
16
- [Training]
14
+ [Optimize]
17
15
  train_seconds=30
18
16
  policy_hyperparams=POLICY_WEIGHT_TUNE
19
17
  print_summary=False
@@ -25,36 +25,6 @@ from pyRDDLGym_jax.core.planner import (
25
25
  load_config, JaxBackpropPlanner, JaxOfflineController, JaxOnlineController
26
26
  )
27
27
 
28
-
29
- def run_cnn1d():
30
- import haiku as hk
31
- import jax
32
- import jax.numpy as jnp
33
-
34
- class CNN(hk.Module):
35
- def __init__(self, name=None):
36
- super().__init__(name=name)
37
- self.conv1d_layer = hk.Conv1D(
38
- output_channels=4,
39
- kernel_shape=6, # Kernel size for 1D convolution
40
- padding="SAME",
41
- name="conv"
42
- )
43
-
44
- def __call__(self, x):
45
- return self.conv1d_layer(x)
46
-
47
- # Example usage:
48
- key = jax.random.PRNGKey(42)
49
- input_data = jnp.ones([1, 4]) # Batch size 1, sequence length 4 (comment previously said length 10, contradicting the shape)
50
-
51
- # Transform the Haiku module into a pure function
52
- f = hk.transform(lambda x: CNN()(x))
53
- params = f.init(key, input_data)
54
- print(params['cnn/~/conv']['w'].shape)
55
- print(params['cnn/~/conv']['b'].shape)
56
- print(f.apply(params, key, input_data).shape)
57
-
58
28
 
59
29
  def main(domain: str, instance: str, method: str, episodes: int=1) -> None:
60
30
 
@@ -76,8 +46,8 @@ def main(domain: str, instance: str, method: str, episodes: int=1) -> None:
76
46
  exit(1)
77
47
 
78
48
  planner_args, _, train_args = load_config(config_path)
79
- if 'dashboard' in train_args:
80
- train_args['dashboard'].launch()
49
+ if 'dashboard' in planner_args:
50
+ planner_args['dashboard'].launch()
81
51
 
82
52
  # create the planning algorithm
83
53
  planner = JaxBackpropPlanner(
@@ -93,7 +63,6 @@ def main(domain: str, instance: str, method: str, episodes: int=1) -> None:
93
63
 
94
64
 
95
65
  def run_from_args(args):
96
- run_cnn1d()
97
66
  if len(args) < 3:
98
67
  print('python run_plan.py <domain> <instance> <method> [<episodes>]')
99
68
  exit(1)
@@ -77,8 +77,8 @@ def main(domain: str, instance: str, method: str,
77
77
  # evaluate the agent on the best parameters
78
78
  planner_args, _, train_args = load_config_from_string(tuning.best_config)
79
79
  planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
80
- klass = JaxOnlineController if method == 'replan' else JaxOfflineController
81
- controller = klass(planner, **train_args)
80
+ class_ = JaxOnlineController if method == 'replan' else JaxOfflineController
81
+ controller = class_(planner, **train_args)
82
82
  controller.evaluate(env, episodes=1, verbose=True, render=True)
83
83
  env.close()
84
84