pyRDDLGym-jax 0.5__py3-none-any.whl → 1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. pyRDDLGym_jax/__init__.py +1 -1
  2. pyRDDLGym_jax/core/compiler.py +463 -592
  3. pyRDDLGym_jax/core/logic.py +784 -544
  4. pyRDDLGym_jax/core/planner.py +329 -463
  5. pyRDDLGym_jax/core/simulator.py +7 -5
  6. pyRDDLGym_jax/core/tuning.py +379 -568
  7. pyRDDLGym_jax/core/visualization.py +1463 -0
  8. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
  9. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +4 -5
  10. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
  11. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
  12. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
  13. pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
  14. pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
  15. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
  16. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +3 -3
  17. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
  18. pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
  19. pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
  20. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
  21. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +3 -3
  22. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
  23. pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
  24. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
  25. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +3 -3
  26. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
  27. pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
  28. pyRDDLGym_jax/examples/configs/default_replan.cfg +3 -3
  29. pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
  30. pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
  31. pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
  32. pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
  33. pyRDDLGym_jax/examples/run_plan.py +4 -1
  34. pyRDDLGym_jax/examples/run_tune.py +40 -27
  35. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/METADATA +161 -104
  36. pyRDDLGym_jax-1.0.dist-info/RECORD +45 -0
  37. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/WHEEL +1 -1
  38. pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
  39. pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
  40. pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
  41. pyRDDLGym_jax-0.5.dist-info/RECORD +0 -44
  42. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/LICENSE +0 -0
  43. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,12 @@
1
1
  [Model]
2
2
  logic='FuzzyLogic'
3
- logic_kwargs={'weight': 20}
4
- tnorm='ProductTNorm'
5
- tnorm_kwargs={}
3
+ comparison_kwargs={'weight': 20}
4
+ rounding_kwargs={'weight': 20}
5
+ control_kwargs={'weight': 20}
6
6
 
7
7
  [Optimizer]
8
8
  method='JaxDeepReactivePolicy'
9
- method_kwargs={'topology': [32, 32]}
9
+ method_kwargs={'topology': [32, 16]}
10
10
  optimizer='rmsprop'
11
11
  optimizer_kwargs={'learning_rate': 0.005}
12
12
  batch_size_train=1
@@ -14,5 +14,4 @@ batch_size_test=1
14
14
 
15
15
  [Training]
16
16
  key=42
17
- epochs=2000
18
- train_seconds=30
17
+ epochs=1000
@@ -1,8 +1,8 @@
1
1
  [Model]
2
2
  logic='FuzzyLogic'
3
- logic_kwargs={'weight': 100}
4
- tnorm='ProductTNorm'
5
- tnorm_kwargs={}
3
+ comparison_kwargs={'weight': 50}
4
+ rounding_kwargs={'weight': 50}
5
+ control_kwargs={'weight': 50}
6
6
 
7
7
  [Optimizer]
8
8
  method='JaxStraightLinePlan'
@@ -15,6 +15,5 @@ rollout_horizon=30
15
15
 
16
16
  [Training]
17
17
  key=42
18
- epochs=1000
19
- train_seconds=1
18
+ train_seconds=0.5
20
19
  print_summary=False
@@ -1,19 +1,18 @@
1
1
  [Model]
2
2
  logic='FuzzyLogic'
3
- logic_kwargs={'weight': 20}
4
- tnorm='ProductTNorm'
5
- tnorm_kwargs={}
3
+ comparison_kwargs={'weight': 30}
4
+ rounding_kwargs={'weight': 30}
5
+ control_kwargs={'weight': 30}
6
6
 
7
7
  [Optimizer]
8
8
  method='JaxStraightLinePlan'
9
9
  method_kwargs={}
10
10
  optimizer='rmsprop'
11
- optimizer_kwargs={'learning_rate': 0.001}
11
+ optimizer_kwargs={'learning_rate': 0.002}
12
12
  batch_size_train=1
13
13
  batch_size_test=1
14
14
  clip_grad=1.0
15
15
 
16
16
  [Training]
17
17
  key=42
18
- epochs=5000
19
- train_seconds=30
18
+ epochs=5000
@@ -1,8 +1,8 @@
1
1
  [Model]
2
2
  logic='FuzzyLogic'
3
- logic_kwargs={'weight': 10}
4
- tnorm='ProductTNorm'
5
- tnorm_kwargs={}
3
+ comparison_kwargs={'weight': 5}
4
+ rounding_kwargs={'weight': 5}
5
+ control_kwargs={'weight': 5}
6
6
 
7
7
  [Optimizer]
8
8
  method='JaxDeepReactivePolicy'
@@ -1,14 +1,14 @@
1
1
  [Model]
2
2
  logic='FuzzyLogic'
3
- logic_kwargs={'weight': 10}
4
- tnorm='ProductTNorm'
5
- tnorm_kwargs={}
3
+ comparison_kwargs={'weight': 5}
4
+ rounding_kwargs={'weight': 5}
5
+ control_kwargs={'weight': 5}
6
6
 
7
7
  [Optimizer]
8
8
  method='JaxStraightLinePlan'
9
9
  method_kwargs={}
10
10
  optimizer='rmsprop'
11
- optimizer_kwargs={'learning_rate': 0.01}
11
+ optimizer_kwargs={'learning_rate': 0.02}
12
12
  batch_size_train=1
13
13
  batch_size_test=1
14
14
 
@@ -1,8 +1,8 @@
1
1
  [Model]
2
2
  logic='FuzzyLogic'
3
- logic_kwargs={'weight': 10}
4
- tnorm='ProductTNorm'
5
- tnorm_kwargs={}
3
+ comparison_kwargs={'weight': 10}
4
+ rounding_kwargs={'weight': 10}
5
+ control_kwargs={'weight': 10}
6
6
 
7
7
  [Optimizer]
8
8
  method='JaxStraightLinePlan'
@@ -1,8 +1,8 @@
1
1
  [Model]
2
2
  logic='FuzzyLogic'
3
- logic_kwargs={'weight': 10}
4
- tnorm='ProductTNorm'
5
- tnorm_kwargs={}
3
+ comparison_kwargs={'weight': 10}
4
+ rounding_kwargs={'weight': 10}
5
+ control_kwargs={'weight': 10}
6
6
 
7
7
  [Optimizer]
8
8
  method='JaxStraightLinePlan'
@@ -1,8 +1,8 @@
1
1
  [Model]
2
2
  logic='FuzzyLogic'
3
- logic_kwargs={'weight': 10}
4
- tnorm='ProductTNorm'
5
- tnorm_kwargs={}
3
+ comparison_kwargs={'weight': 10}
4
+ rounding_kwargs={'weight': 10}
5
+ control_kwargs={'weight': 10}
6
6
 
7
7
  [Optimizer]
8
8
  method='JaxDeepReactivePolicy'
@@ -1,8 +1,8 @@
1
1
  [Model]
2
2
  logic='FuzzyLogic'
3
- logic_kwargs={'weight': 10}
4
- tnorm='ProductTNorm'
5
- tnorm_kwargs={}
3
+ comparison_kwargs={'weight': 10}
4
+ rounding_kwargs={'weight': 10}
5
+ control_kwargs={'weight': 10}
6
6
 
7
7
  [Optimizer]
8
8
  method='JaxStraightLinePlan'
@@ -1,8 +1,8 @@
1
1
  [Model]
2
2
  logic='FuzzyLogic'
3
- logic_kwargs={'weight': 10}
4
- tnorm='ProductTNorm'
5
- tnorm_kwargs={}
3
+ comparison_kwargs={'weight': 10}
4
+ rounding_kwargs={'weight': 10}
5
+ control_kwargs={'weight': 10}
6
6
 
7
7
  [Optimizer]
8
8
  method='JaxStraightLinePlan'
@@ -1,8 +1,8 @@
1
1
  [Model]
2
2
  logic='FuzzyLogic'
3
- logic_kwargs={'weight': 100}
4
- tnorm='ProductTNorm'
5
- tnorm_kwargs={}
3
+ comparison_kwargs={'weight': 10}
4
+ rounding_kwargs={'weight': 10}
5
+ control_kwargs={'weight': 10}
6
6
 
7
7
  [Optimizer]
8
8
  method='JaxDeepReactivePolicy'
@@ -1,14 +1,14 @@
1
1
  [Model]
2
2
  logic='FuzzyLogic'
3
- logic_kwargs={'weight': 500}
4
- tnorm='ProductTNorm'
5
- tnorm_kwargs={}
3
+ comparison_kwargs={'weight': 50}
4
+ rounding_kwargs={'weight': 50}
5
+ control_kwargs={'weight': 50}
6
6
 
7
7
  [Optimizer]
8
8
  method='JaxStraightLinePlan'
9
9
  method_kwargs={}
10
10
  optimizer='rmsprop'
11
- optimizer_kwargs={'learning_rate': 0.01}
11
+ optimizer_kwargs={'learning_rate': 0.03}
12
12
  batch_size_train=1
13
13
  batch_size_test=1
14
14
 
@@ -1,8 +1,8 @@
1
1
  [Model]
2
2
  logic='FuzzyLogic'
3
- logic_kwargs={'weight': 10}
4
- tnorm='ProductTNorm'
5
- tnorm_kwargs={}
3
+ comparison_kwargs={'weight': 10}
4
+ rounding_kwargs={'weight': 10}
5
+ control_kwargs={'weight': 10}
6
6
 
7
7
  [Optimizer]
8
8
  method='JaxDeepReactivePolicy'
@@ -1,8 +1,8 @@
1
1
  [Model]
2
2
  logic='FuzzyLogic'
3
- logic_kwargs={'weight': 20}
4
- tnorm='ProductTNorm'
5
- tnorm_kwargs={}
3
+ comparison_kwargs={'weight': 10}
4
+ rounding_kwargs={'weight': 10}
5
+ control_kwargs={'weight': 10}
6
6
 
7
7
  [Optimizer]
8
8
  method='JaxStraightLinePlan'
@@ -1,8 +1,8 @@
1
1
  [Model]
2
2
  logic='FuzzyLogic'
3
- logic_kwargs={'weight': 10}
4
- tnorm='ProductTNorm'
5
- tnorm_kwargs={}
3
+ comparison_kwargs={'weight': 10}
4
+ rounding_kwargs={'weight': 10}
5
+ control_kwargs={'weight': 10}
6
6
 
7
7
  [Optimizer]
8
8
  method='JaxStraightLinePlan'
@@ -14,5 +14,5 @@ batch_size_test=32
14
14
 
15
15
  [Training]
16
16
  key=42
17
- epochs=5000
18
- train_seconds=60
17
+ epochs=2000
18
+ train_seconds=30
@@ -1,14 +1,14 @@
1
1
  [Model]
2
2
  logic='FuzzyLogic'
3
- logic_kwargs={'weight': 1.0}
4
- tnorm='ProductTNorm'
5
- tnorm_kwargs={}
3
+ comparison_kwargs={'weight': 1}
4
+ rounding_kwargs={'weight': 1}
5
+ control_kwargs={'weight': 1}
6
6
 
7
7
  [Optimizer]
8
8
  method='JaxStraightLinePlan'
9
9
  method_kwargs={}
10
10
  optimizer='rmsprop'
11
- optimizer_kwargs={'learning_rate': 0.0005}
11
+ optimizer_kwargs={'learning_rate': 0.0003}
12
12
  batch_size_train=1
13
13
  batch_size_test=1
14
14
 
@@ -1,8 +1,8 @@
1
1
  [Model]
2
2
  logic='FuzzyLogic'
3
- logic_kwargs={'weight': 100}
4
- tnorm='ProductTNorm'
5
- tnorm_kwargs={}
3
+ comparison_kwargs={'weight': 100}
4
+ rounding_kwargs={'weight': 100}
5
+ control_kwargs={'weight': 100}
6
6
 
7
7
  [Optimizer]
8
8
  method='JaxDeepReactivePolicy'
@@ -1,8 +1,8 @@
1
1
  [Model]
2
2
  logic='FuzzyLogic'
3
- logic_kwargs={'weight': 100}
4
- tnorm='ProductTNorm'
5
- tnorm_kwargs={}
3
+ comparison_kwargs={'weight': 100}
4
+ rounding_kwargs={'weight': 100}
5
+ control_kwargs={'weight': 100}
6
6
 
7
7
  [Optimizer]
8
8
  method='JaxStraightLinePlan'
@@ -1,8 +1,8 @@
1
1
  [Model]
2
2
  logic='FuzzyLogic'
3
- logic_kwargs={'weight': 100}
4
- tnorm='ProductTNorm'
5
- tnorm_kwargs={}
3
+ comparison_kwargs={'weight': 100}
4
+ rounding_kwargs={'weight': 100}
5
+ control_kwargs={'weight': 100}
6
6
 
7
7
  [Optimizer]
8
8
  method='JaxStraightLinePlan'
@@ -1,8 +1,8 @@
1
1
  [Model]
2
2
  logic='FuzzyLogic'
3
- logic_kwargs={'weight': 20}
4
- tnorm='ProductTNorm'
5
- tnorm_kwargs={}
3
+ comparison_kwargs={'weight': 20}
4
+ rounding_kwargs={'weight': 20}
5
+ control_kwargs={'weight': 20}
6
6
 
7
7
  [Optimizer]
8
8
  method='JaxDeepReactivePolicy'
@@ -1,8 +1,8 @@
1
1
  [Model]
2
2
  logic='FuzzyLogic'
3
- logic_kwargs={'weight': 20}
4
- tnorm='ProductTNorm'
5
- tnorm_kwargs={}
3
+ comparison_kwargs={'weight': 20}
4
+ rounding_kwargs={'weight': 20}
5
+ control_kwargs={'weight': 20}
6
6
 
7
7
  [Optimizer]
8
8
  method='JaxStraightLinePlan'
@@ -1,8 +1,8 @@
1
1
  [Model]
2
2
  logic='FuzzyLogic'
3
- logic_kwargs={'weight': 20}
4
- tnorm='ProductTNorm'
5
- tnorm_kwargs={}
3
+ comparison_kwargs={'weight': 20}
4
+ rounding_kwargs={'weight': 20}
5
+ control_kwargs={'weight': 20}
6
6
 
7
7
  [Optimizer]
8
8
  method='JaxStraightLinePlan'
@@ -0,0 +1,19 @@
1
+ [Model]
2
+ logic='FuzzyLogic'
3
+ comparison_kwargs={'weight': MODEL_WEIGHT_TUNE}
4
+ rounding_kwargs={'weight': MODEL_WEIGHT_TUNE}
5
+ control_kwargs={'weight': MODEL_WEIGHT_TUNE}
6
+
7
+ [Optimizer]
8
+ method='JaxDeepReactivePolicy'
9
+ method_kwargs={'topology': [LAYER1_TUNE, LAYER2_TUNE]}
10
+ optimizer='rmsprop'
11
+ optimizer_kwargs={'learning_rate': LEARNING_RATE_TUNE}
12
+ batch_size_train=32
13
+ batch_size_test=32
14
+
15
+ [Training]
16
+ train_seconds=30
17
+ policy_hyperparams=POLICY_WEIGHT_TUNE
18
+ print_summary=False
19
+ print_progress=False
@@ -0,0 +1,20 @@
1
+ [Model]
2
+ logic='FuzzyLogic'
3
+ comparison_kwargs={'weight': MODEL_WEIGHT_TUNE}
4
+ rounding_kwargs={'weight': MODEL_WEIGHT_TUNE}
5
+ control_kwargs={'weight': MODEL_WEIGHT_TUNE}
6
+
7
+ [Optimizer]
8
+ method='JaxStraightLinePlan'
9
+ method_kwargs={}
10
+ optimizer='rmsprop'
11
+ optimizer_kwargs={'learning_rate': LEARNING_RATE_TUNE}
12
+ batch_size_train=32
13
+ batch_size_test=32
14
+ rollout_horizon=ROLLOUT_HORIZON_TUNE
15
+
16
+ [Training]
17
+ train_seconds=1
18
+ policy_hyperparams=POLICY_WEIGHT_TUNE
19
+ print_summary=False
20
+ print_progress=False
@@ -0,0 +1,19 @@
1
+ [Model]
2
+ logic='FuzzyLogic'
3
+ comparison_kwargs={'weight': MODEL_WEIGHT_TUNE}
4
+ rounding_kwargs={'weight': MODEL_WEIGHT_TUNE}
5
+ control_kwargs={'weight': MODEL_WEIGHT_TUNE}
6
+
7
+ [Optimizer]
8
+ method='JaxStraightLinePlan'
9
+ method_kwargs={}
10
+ optimizer='rmsprop'
11
+ optimizer_kwargs={'learning_rate': LEARNING_RATE_TUNE}
12
+ batch_size_train=32
13
+ batch_size_test=32
14
+
15
+ [Training]
16
+ train_seconds=30
17
+ policy_hyperparams=POLICY_WEIGHT_TUNE
18
+ print_summary=False
19
+ print_progress=False
@@ -39,9 +39,12 @@ def main(domain, instance, method, episodes=1):
39
39
  f'using default_{method}.cfg.', 'red')
40
40
  config_path = os.path.join(abs_path, 'configs', f'default_{method}.cfg')
41
41
  planner_args, _, train_args = load_config(config_path)
42
+ if 'dashboard' in train_args:
43
+ train_args['dashboard'].launch()
42
44
 
43
45
  # create the planning algorithm
44
- planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
46
+ planner = JaxBackpropPlanner(
47
+ rddl=env.model, dashboard_viz=env._visualizer, **planner_args)
45
48
 
46
49
  # evaluate the controller
47
50
  if method == 'replan':
@@ -20,12 +20,19 @@ import os
20
20
  import sys
21
21
 
22
22
  import pyRDDLGym
23
- from pyRDDLGym.core.debug.exception import raise_warning
24
23
 
25
- from pyRDDLGym_jax.core.tuning import (
26
- JaxParameterTuningDRP, JaxParameterTuningSLP, JaxParameterTuningSLPReplan
24
+ from pyRDDLGym_jax.core.tuning import JaxParameterTuning, Hyperparameter
25
+ from pyRDDLGym_jax.core.planner import (
26
+ load_config_from_string, JaxBackpropPlanner,
27
+ JaxOfflineController, JaxOnlineController
27
28
  )
28
- from pyRDDLGym_jax.core.planner import load_config
29
+
30
+ def power_2(x):
31
+ return int(2 ** x)
32
+
33
+
34
+ def power_10(x):
35
+ return 10.0 ** x
29
36
 
30
37
 
31
38
  def main(domain, instance, method, trials=5, iters=20, workers=4):
@@ -35,31 +42,37 @@ def main(domain, instance, method, trials=5, iters=20, workers=4):
35
42
 
36
43
  # load the config file with planner settings
37
44
  abs_path = os.path.dirname(os.path.abspath(__file__))
38
- config_path = os.path.join(abs_path, 'configs', f'{domain}_{method}.cfg')
39
- if not os.path.isfile(config_path):
40
- raise_warning(f'Config file {config_path} was not found, '
41
- f'using default_{method}.cfg.', 'red')
42
- config_path = os.path.join(abs_path, 'configs', f'default_{method}.cfg')
43
- planner_args, plan_args, train_args = load_config(config_path)
45
+ config_path = os.path.join(abs_path, 'configs', f'tuning_{method}.cfg')
46
+ with open(config_path, 'r') as file:
47
+ config_template = file.read()
48
+
49
+ # map parameters in the config that will be tuned
50
+ hyperparams = [
51
+ Hyperparameter('MODEL_WEIGHT_TUNE', -1., 5., power_10),
52
+ Hyperparameter('POLICY_WEIGHT_TUNE', -2., 2., power_10),
53
+ Hyperparameter('LEARNING_RATE_TUNE', -5., 1., power_10),
54
+ Hyperparameter('LAYER1_TUNE', 1, 8, power_2),
55
+ Hyperparameter('LAYER2_TUNE', 1, 8, power_2),
56
+ Hyperparameter('ROLLOUT_HORIZON_TUNE', 1, min(env.horizon, 100), int)
57
+ ]
44
58
 
45
- # define algorithm to perform tuning
46
- if method == 'slp':
47
- tuning_class = JaxParameterTuningSLP
48
- elif method == 'drp':
49
- tuning_class = JaxParameterTuningDRP
50
- elif method == 'replan':
51
- tuning_class = JaxParameterTuningSLPReplan
52
- tuning = tuning_class(env=env,
53
- train_epochs=train_args['epochs'],
54
- timeout_training=train_args['train_seconds'],
55
- eval_trials=trials,
56
- planner_kwargs=planner_args,
57
- plan_kwargs=plan_args,
58
- num_workers=workers,
59
- gp_iters=iters)
59
+ # build the tuner and tune
60
+ tuning = JaxParameterTuning(env=env,
61
+ config_template=config_template,
62
+ hyperparams=hyperparams,
63
+ online=method == 'replan',
64
+ eval_trials=trials,
65
+ num_workers=workers,
66
+ gp_iters=iters)
67
+ tuning.tune(key=42, log_file=f'gp_{method}_{domain}_{instance}.csv')
60
68
 
61
- # perform tuning and report best parameters
62
- tuning.tune(key=train_args['key'], filename=f'gp_{method}', save_plot=True)
69
+ # evaluate the agent on the best parameters
70
+ planner_args, _, train_args = load_config_from_string(tuning.best_config)
71
+ planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
72
+ klass = JaxOnlineController if method == 'replan' else JaxOfflineController
73
+ controller = klass(planner, **train_args)
74
+ controller.evaluate(env, episodes=1, verbose=True, render=True)
75
+ env.close()
63
76
 
64
77
 
65
78
  if __name__ == "__main__":