pyRDDLGym-jax: 0.1-py3-none-any.whl → 0.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. pyRDDLGym_jax/__init__.py +1 -0
  2. pyRDDLGym_jax/core/compiler.py +444 -221
  3. pyRDDLGym_jax/core/logic.py +129 -62
  4. pyRDDLGym_jax/core/planner.py +965 -394
  5. pyRDDLGym_jax/core/simulator.py +5 -7
  6. pyRDDLGym_jax/core/tuning.py +29 -15
  7. pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_drp.cfg → Cartpole_Continuous_gym_drp.cfg} +2 -3
  8. pyRDDLGym_jax/examples/configs/{HVAC_drp.cfg → HVAC_ippc2023_drp.cfg} +4 -4
  9. pyRDDLGym_jax/examples/configs/{MarsRover_drp.cfg → MarsRover_ippc2023_drp.cfg} +1 -0
  10. pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +19 -0
  11. pyRDDLGym_jax/examples/configs/{Pendulum_slp.cfg → Pendulum_gym_slp.cfg} +1 -1
  12. pyRDDLGym_jax/examples/configs/{Pong_slp.cfg → Quadcopter_drp.cfg} +5 -5
  13. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +18 -0
  14. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +1 -1
  15. pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +1 -1
  16. pyRDDLGym_jax/examples/configs/default_drp.cfg +19 -0
  17. pyRDDLGym_jax/examples/configs/default_replan.cfg +20 -0
  18. pyRDDLGym_jax/examples/configs/default_slp.cfg +19 -0
  19. pyRDDLGym_jax/examples/run_gradient.py +1 -1
  20. pyRDDLGym_jax/examples/run_gym.py +3 -7
  21. pyRDDLGym_jax/examples/run_plan.py +10 -5
  22. pyRDDLGym_jax/examples/run_scipy.py +61 -0
  23. pyRDDLGym_jax/examples/run_tune.py +8 -3
  24. {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/METADATA +1 -1
  25. pyRDDLGym_jax-0.3.dist-info/RECORD +44 -0
  26. {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/WHEEL +1 -1
  27. pyRDDLGym_jax/examples/configs/SupplyChain_slp.cfg +0 -18
  28. pyRDDLGym_jax/examples/configs/Traffic_slp.cfg +0 -20
  29. pyRDDLGym_jax-0.1.dist-info/RECORD +0 -40
  30. /pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_replan.cfg → Cartpole_Continuous_gym_replan.cfg} +0 -0
  31. /pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_slp.cfg → Cartpole_Continuous_gym_slp.cfg} +0 -0
  32. /pyRDDLGym_jax/examples/configs/{HVAC_slp.cfg → HVAC_ippc2023_slp.cfg} +0 -0
  33. /pyRDDLGym_jax/examples/configs/{MarsRover_slp.cfg → MarsRover_ippc2023_slp.cfg} +0 -0
  34. /pyRDDLGym_jax/examples/configs/{MountainCar_slp.cfg → MountainCar_Continuous_gym_slp.cfg} +0 -0
  35. /pyRDDLGym_jax/examples/configs/{PowerGen_drp.cfg → PowerGen_Continuous_drp.cfg} +0 -0
  36. /pyRDDLGym_jax/examples/configs/{PowerGen_replan.cfg → PowerGen_Continuous_replan.cfg} +0 -0
  37. /pyRDDLGym_jax/examples/configs/{PowerGen_slp.cfg → PowerGen_Continuous_slp.cfg} +0 -0
  38. {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/LICENSE +0 -0
  39. {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/simulator.py
@@ -1,6 +1,6 @@
  import jax
  import time
- from typing import Dict
+ from typing import Dict, Optional

  from pyRDDLGym.core.compiler.model import RDDLLiftedModel
  from pyRDDLGym.core.debug.exception import (
@@ -20,9 +20,9 @@ Args = Dict[str, Value]
  class JaxRDDLSimulator(RDDLSimulator):

  def __init__(self, rddl: RDDLLiftedModel,
- key: jax.random.PRNGKey=None,
+ key: Optional[jax.random.PRNGKey]=None,
  raise_error: bool=True,
- logger: Logger=None,
+ logger: Optional[Logger]=None,
  keep_tensors: bool=False,
  **compiler_args) -> None:
  '''Creates a new simulator for the given RDDL model with Jax as a backend.
@@ -56,10 +56,8 @@ class JaxRDDLSimulator(RDDLSimulator):
  rddl = self.rddl

  # compilation
- if self.logger is not None:
- self.logger.clear()
  compiled = JaxRDDLCompiler(rddl, logger=self.logger, **self.compiler_args)
- compiled.compile(log_jax_expr=True)
+ compiled.compile(log_jax_expr=True, heading='SIMULATION MODEL')

  self.init_values = compiled.init_values
  self.levels = compiled.levels
@@ -96,7 +94,7 @@ class JaxRDDLSimulator(RDDLSimulator):
  self.precond_names = [f'Precondition {i}' for i in range(len(rddl.preconditions))]
  self.terminal_names = [f'Termination {i}' for i in range(len(rddl.terminations))]

- def handle_error_code(self, error, msg) -> None:
+ def handle_error_code(self, error: int, msg: str) -> None:
  if self.raise_error:
  errors = JaxRDDLCompiler.get_error_messages(error)
  if errors:

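For context, a minimal usage sketch of the simulator backend, mirroring run_gym.py further down in this diff. The domain and instance identifiers are placeholders, and the RandomAgent import path is assumed from pyRDDLGym rather than taken from this diff:

    import pyRDDLGym
    from pyRDDLGym.core.policy import RandomAgent  # assumed import path
    from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator

    # placeholder domain/instance; the key and logger arguments of JaxRDDLSimulator
    # are now Optional and default to None, so no extra setup is needed here
    env = pyRDDLGym.make('Cartpole_Continuous_gym', '0', backend=JaxRDDLSimulator)
    agent = RandomAgent(action_space=env.action_space,
                        num_actions=env.max_allowed_actions,
                        seed=42)
    agent.evaluate(env, episodes=1, verbose=True, render=True, seed=42)
    env.close()
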
pyRDDLGym_jax/core/tuning.py
@@ -8,7 +8,9 @@ from multiprocessing import get_context
  import numpy as np
  import os
  import time
- from typing import Callable, Dict, Tuple
+ from typing import Any, Callable, Dict, Optional, Tuple
+
+ Kwargs = Dict[str, Any]

  import warnings
  warnings.filterwarnings("ignore")
@@ -45,15 +47,15 @@ class JaxParameterTuning:
  timeout_tuning: float=np.inf,
  eval_trials: int=5,
  verbose: bool=True,
- planner_kwargs: Dict={},
- plan_kwargs: Dict={},
+ planner_kwargs: Optional[Kwargs]=None,
+ plan_kwargs: Optional[Kwargs]=None,
  pool_context: str='spawn',
  num_workers: int=1,
  poll_frequency: float=0.2,
  gp_iters: int=25,
- acquisition=None,
- gp_init_kwargs: Dict={},
- gp_params: Dict={'n_restarts_optimizer': 10}) -> None:
+ acquisition: Optional[UtilityFunction]=None,
+ gp_init_kwargs: Optional[Kwargs]=None,
+ gp_params: Optional[Kwargs]=None) -> None:
  '''Creates a new instance for tuning hyper-parameters for Jax planners
  on the given RDDL domain and instance.

@@ -93,13 +95,21 @@ class JaxParameterTuning:
  self.timeout_tuning = timeout_tuning
  self.eval_trials = eval_trials
  self.verbose = verbose
+ if planner_kwargs is None:
+ planner_kwargs = {}
  self.planner_kwargs = planner_kwargs
+ if plan_kwargs is None:
+ plan_kwargs = {}
  self.plan_kwargs = plan_kwargs
  self.pool_context = pool_context
  self.num_workers = num_workers
  self.poll_frequency = poll_frequency
  self.gp_iters = gp_iters
+ if gp_init_kwargs is None:
+ gp_init_kwargs = {}
  self.gp_init_kwargs = gp_init_kwargs
+ if gp_params is None:
+ gp_params = {'n_restarts_optimizer': 10}
  self.gp_params = gp_params

  # create acquisition function
@@ -109,7 +119,7 @@ class JaxParameterTuning:
  acquisition, self.acq_args = JaxParameterTuning._annealing_utility(num_samples)
  self.acquisition = acquisition

- def summarize_hyperparameters(self):
+ def summarize_hyperparameters(self) -> None:
  print(f'hyperparameter optimizer parameters:\n'
  f' tuned_hyper_parameters ={self.hyperparams_dict}\n'
  f' initialization_args ={self.gp_init_kwargs}\n'
@@ -150,8 +160,9 @@ class JaxParameterTuning:
  pid = os.getpid()
  return index, pid, params, target

- def tune(self, key: jax.random.PRNGKey, filename: str,
- save_plot: bool=False) -> Dict[str, object]:
+ def tune(self, key: jax.random.PRNGKey,
+ filename: str,
+ save_plot: bool=False) -> Dict[str, Any]:
  '''Tunes the hyper-parameters for Jax planner, returns the best found.'''
  self.summarize_hyperparameters()

@@ -357,14 +368,15 @@ def objective_slp(params, kwargs, key, index):
  train_seconds=kwargs['timeout_training'],
  model_params=model_params,
  policy_hyperparams=policy_hparams,
- verbose=0,
+ print_summary=False,
+ print_progress=False,
  tqdm_position=index)

  # initialize env for evaluation (need fresh copy to avoid concurrency)
  env = RDDLEnv(domain=kwargs['domain'],
  instance=kwargs['instance'],
  vectorized=True,
- enforce_action_constraints=True)
+ enforce_action_constraints=False)

  # perform training
  average_reward = 0.0
@@ -488,14 +500,15 @@ def objective_replan(params, kwargs, key, index):
  train_seconds=kwargs['timeout_training'],
  model_params=model_params,
  policy_hyperparams=policy_hparams,
- verbose=0,
+ print_summary=False,
+ print_progress=False,
  tqdm_position=index)

  # initialize env for evaluation (need fresh copy to avoid concurrency)
  env = RDDLEnv(domain=kwargs['domain'],
  instance=kwargs['instance'],
  vectorized=True,
- enforce_action_constraints=True)
+ enforce_action_constraints=False)

  # perform training
  average_reward = 0.0
@@ -615,14 +628,15 @@ def objective_drp(params, kwargs, key, index):
  train_seconds=kwargs['timeout_training'],
  model_params=model_params,
  policy_hyperparams=policy_hparams,
- verbose=0,
+ print_summary=False,
+ print_progress=False,
  tqdm_position=index)

  # initialize env for evaluation (need fresh copy to avoid concurrency)
  env = RDDLEnv(domain=kwargs['domain'],
  instance=kwargs['instance'],
  vectorized=True,
- enforce_action_constraints=True)
+ enforce_action_constraints=False)

  # perform training
  average_reward = 0.0

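As a usage sketch of the tuner with the new keyword defaults (planner_kwargs, plan_kwargs, gp_init_kwargs and gp_params may now simply be omitted), mirroring run_tune.py further down in this diff. The domain, instance and config path are placeholders:

    import pyRDDLGym
    from pyRDDLGym_jax.core.planner import load_config
    from pyRDDLGym_jax.core.tuning import JaxParameterTuningSLP

    # placeholder domain/instance and config location
    env = pyRDDLGym.make('HVAC_ippc2023', '1', vectorized=True)
    planner_args, plan_args, train_args = load_config('configs/default_slp.cfg')

    # planner_kwargs/plan_kwargs/gp_init_kwargs/gp_params left at their None defaults
    tuning = JaxParameterTuningSLP(env=env,
                                   train_epochs=train_args['epochs'],
                                   timeout_training=train_args['train_seconds'],
                                   eval_trials=5,
                                   num_workers=4,
                                   gp_iters=20)
    best = tuning.tune(key=train_args['key'], filename='gp_slp', save_plot=True)
    print(f'best parameters found: {best}')
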
pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_drp.cfg → Cartpole_Continuous_gym_drp.cfg}
@@ -8,12 +8,11 @@ tnorm_kwargs={}
  method='JaxDeepReactivePolicy'
  method_kwargs={'topology': [32, 32]}
  optimizer='rmsprop'
- optimizer_kwargs={'learning_rate': 0.001}
+ optimizer_kwargs={'learning_rate': 0.005}
  batch_size_train=1
  batch_size_test=1
- clip_grad=1.0

  [Training]
  key=42
- epochs=5000
+ epochs=2000
  train_seconds=30

pyRDDLGym_jax/examples/configs/{HVAC_drp.cfg → HVAC_ippc2023_drp.cfg}
@@ -6,13 +6,13 @@ tnorm_kwargs={}

  [Optimizer]
  method='JaxDeepReactivePolicy'
- method_kwargs={'topology': [128, 128]}
+ method_kwargs={'topology': [64, 64]}
  optimizer='rmsprop'
- optimizer_kwargs={'learning_rate': 0.0003}
+ optimizer_kwargs={'learning_rate': 0.001}
  batch_size_train=1
  batch_size_test=1

  [Training]
  key=42
- epochs=2000
- train_seconds=30
+ epochs=6000
+ train_seconds=60

pyRDDLGym_jax/examples/configs/{MarsRover_drp.cfg → MarsRover_ippc2023_drp.cfg}
@@ -11,6 +11,7 @@ optimizer='rmsprop'
  optimizer_kwargs={'learning_rate': 0.01}
  batch_size_train=1
  batch_size_test=1
+ action_bounds={'power-x': (-0.09999, 0.09999), 'power-y': (-0.09999, 0.09999)}

  [Training]
  key=42

pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg
@@ -0,0 +1,19 @@
+ [Model]
+ logic='FuzzyLogic'
+ logic_kwargs={'weight': 10}
+ tnorm='ProductTNorm'
+ tnorm_kwargs={}
+
+ [Optimizer]
+ method='JaxStraightLinePlan'
+ method_kwargs={}
+ optimizer='rmsprop'
+ optimizer_kwargs={'learning_rate': 1.0}
+ batch_size_train=1
+ batch_size_test=1
+ clip_grad=1.0
+
+ [Training]
+ key=42
+ epochs=1000
+ train_seconds=30

pyRDDLGym_jax/examples/configs/{Pendulum_slp.cfg → Pendulum_gym_slp.cfg}
@@ -8,7 +8,7 @@ tnorm_kwargs={}
  method='JaxStraightLinePlan'
  method_kwargs={}
  optimizer='rmsprop'
- optimizer_kwargs={'learning_rate': 5.0}
+ optimizer_kwargs={'learning_rate': 1.0}
  batch_size_train=1
  batch_size_test=1

pyRDDLGym_jax/examples/configs/{Pong_slp.cfg → Quadcopter_drp.cfg}
@@ -1,12 +1,12 @@
  [Model]
  logic='FuzzyLogic'
- logic_kwargs={'weight': 1.0}
+ logic_kwargs={'weight': 100}
  tnorm='ProductTNorm'
  tnorm_kwargs={}

  [Optimizer]
- method='JaxStraightLinePlan'
- method_kwargs={}
+ method='JaxDeepReactivePolicy'
+ method_kwargs={'topology': [256, 128], 'activation': 'tanh'}
  optimizer='rmsprop'
  optimizer_kwargs={'learning_rate': 0.001}
  batch_size_train=1
@@ -14,5 +14,5 @@ batch_size_test=1

  [Training]
  key=42
- epochs=2000
- train_seconds=30
+ epochs=100000
+ train_seconds=360

pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg
@@ -0,0 +1,18 @@
+ [Model]
+ logic='FuzzyLogic'
+ logic_kwargs={'weight': 10}
+ tnorm='ProductTNorm'
+ tnorm_kwargs={}
+
+ [Optimizer]
+ method='JaxDeepReactivePolicy'
+ method_kwargs={'topology': [64, 32]}
+ optimizer='rmsprop'
+ optimizer_kwargs={'learning_rate': 0.0002}
+ batch_size_train=32
+ batch_size_test=32
+
+ [Training]
+ key=42
+ epochs=5000
+ train_seconds=60

pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg
@@ -15,4 +15,4 @@ batch_size_test=32
  [Training]
  key=42
  epochs=5000
- train_seconds=30
+ train_seconds=60

pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg
@@ -7,7 +7,7 @@ tnorm_kwargs={}
  [Optimizer]
  method='JaxStraightLinePlan'
  method_kwargs={}
- optimizer='adam'
+ optimizer='rmsprop'
  optimizer_kwargs={'learning_rate': 0.0005}
  batch_size_train=1
  batch_size_test=1

pyRDDLGym_jax/examples/configs/default_drp.cfg
@@ -0,0 +1,19 @@
+ [Model]
+ logic='FuzzyLogic'
+ logic_kwargs={'weight': 20}
+ tnorm='ProductTNorm'
+ tnorm_kwargs={}
+
+ [Optimizer]
+ method='JaxDeepReactivePolicy'
+ method_kwargs={}
+ optimizer='rmsprop'
+ optimizer_kwargs={'learning_rate': 0.0001}
+ batch_size_train=32
+ batch_size_test=32
+
+ [Training]
+ key=42
+ epochs=30000
+ train_seconds=60
+ policy_hyperparams=2.0

pyRDDLGym_jax/examples/configs/default_replan.cfg
@@ -0,0 +1,20 @@
+ [Model]
+ logic='FuzzyLogic'
+ logic_kwargs={'weight': 20}
+ tnorm='ProductTNorm'
+ tnorm_kwargs={}
+
+ [Optimizer]
+ method='JaxStraightLinePlan'
+ method_kwargs={}
+ optimizer='rmsprop'
+ optimizer_kwargs={'learning_rate': 0.01}
+ batch_size_train=32
+ batch_size_test=32
+ rollout_horizon=5
+
+ [Training]
+ key=42
+ epochs=2000
+ train_seconds=1
+ policy_hyperparams=2.0

pyRDDLGym_jax/examples/configs/default_slp.cfg
@@ -0,0 +1,19 @@
+ [Model]
+ logic='FuzzyLogic'
+ logic_kwargs={'weight': 20}
+ tnorm='ProductTNorm'
+ tnorm_kwargs={}
+
+ [Optimizer]
+ method='JaxStraightLinePlan'
+ method_kwargs={}
+ optimizer='rmsprop'
+ optimizer_kwargs={'learning_rate': 0.01}
+ batch_size_train=32
+ batch_size_test=32
+
+ [Training]
+ key=42
+ epochs=30000
+ train_seconds=60
+ policy_hyperparams=2.0

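For reference, a minimal sketch of how a config such as default_slp.cfg is consumed, following the pattern in run_plan.py shown next in this diff. The domain and instance names here are placeholders:

    import pyRDDLGym
    from pyRDDLGym_jax.core.planner import load_config, JaxBackpropPlanner, JaxOfflineController

    # placeholder domain/instance; any vectorized RDDL environment works the same way
    env = pyRDDLGym.make('Reservoir_Continuous', '1', vectorized=True)

    # load_config returns planner, plan, and training keyword dictionaries
    planner_args, _, train_args = load_config('configs/default_slp.cfg')
    planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
    controller = JaxOfflineController(planner, **train_args)
    controller.evaluate(env, episodes=1, verbose=True, render=True)
    env.close()
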
pyRDDLGym_jax/examples/run_gradient.py
@@ -91,7 +91,7 @@ def main():
  my_args = [jax.random.PRNGKey(42), params, None, subs, compiler.model_params]

  # print the fluents over the trajectory, return and gradient
- print(step_fn(*my_args)['pvar'])
+ print(step_fn(*my_args)['fluents'])
  print(sum_of_rewards(*my_args))
  print(jax.grad(sum_of_rewards, argnums=1)(*my_args))

pyRDDLGym_jax/examples/run_gym.py
@@ -23,17 +23,13 @@ from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator
  def main(domain, instance, episodes=1, seed=42):

  # create the environment
- env = pyRDDLGym.make(domain, instance, enforce_action_constraints=True,
- backend=JaxRDDLSimulator)
- env.seed(seed)
+ env = pyRDDLGym.make(domain, instance, backend=JaxRDDLSimulator)

- # set up a random policy
+ # evaluate a random policy
  agent = RandomAgent(action_space=env.action_space,
  num_actions=env.max_allowed_actions,
  seed=seed)
- agent.evaluate(env, episodes=episodes, verbose=True, render=True)
-
- # important when logging to save all traces
+ agent.evaluate(env, episodes=episodes, verbose=True, render=True, seed=seed)
  env.close()


pyRDDLGym_jax/examples/run_plan.py
@@ -13,11 +13,14 @@ where:
  <domain> is the name of a domain located in the /Examples directory
  <instance> is the instance number
  <method> is either slp, drp, or replan
+ <episodes> is the optional number of evaluation rollouts
  '''
  import os
  import sys

  import pyRDDLGym
+ from pyRDDLGym.core.debug.exception import raise_warning
+
  from pyRDDLGym_jax.core.planner import (
  load_config, JaxBackpropPlanner, JaxOfflineController, JaxOnlineController
  )
@@ -26,24 +29,26 @@ from pyRDDLGym_jax.core.planner import (
  def main(domain, instance, method, episodes=1):

  # set up the environment
- env = pyRDDLGym.make(domain, instance, vectorized=True, enforce_action_constraints=True)
+ env = pyRDDLGym.make(domain, instance, vectorized=True)

  # load the config file with planner settings
  abs_path = os.path.dirname(os.path.abspath(__file__))
  config_path = os.path.join(abs_path, 'configs', f'{domain}_{method}.cfg')
+ if not os.path.isfile(config_path):
+ raise_warning(f'Config file {config_path} was not found, '
+ f'using default_{method}.cfg.', 'red')
+ config_path = os.path.join(abs_path, 'configs', f'default_{method}.cfg')
  planner_args, _, train_args = load_config(config_path)

  # create the planning algorithm
  planner = JaxBackpropPlanner(rddl=env.model, **planner_args)

- # create the controller
+ # evaluate the controller
  if method == 'replan':
  controller = JaxOnlineController(planner, **train_args)
  else:
- controller = JaxOfflineController(planner, **train_args)
-
+ controller = JaxOfflineController(planner, **train_args)
  controller.evaluate(env, episodes=episodes, verbose=True, render=True)
-
  env.close()


pyRDDLGym_jax/examples/run_scipy.py
@@ -0,0 +1,61 @@
+ '''In this example, the user has the choice to run the Jax planner using an
+ optimizer from scipy.minimize.
+
+ The syntax for running this example is:
+
+ python run_scipy.py <domain> <instance> <method> [<episodes>]
+
+ where:
+ <domain> is the name of a domain located in the /Examples directory
+ <instance> is the instance number
+ <method> is the name of a method provided to scipy.optimize.minimize()
+ <episodes> is the optional number of evaluation rollouts
+ '''
+ import os
+ import sys
+ import jax
+ from scipy.optimize import minimize
+
+ import pyRDDLGym
+ from pyRDDLGym.core.debug.exception import raise_warning
+
+ from pyRDDLGym_jax.core.planner import load_config, JaxBackpropPlanner, JaxOfflineController
+
+
+ def main(domain, instance, method, episodes=1):
+
+ # set up the environment
+ env = pyRDDLGym.make(domain, instance, vectorized=True)
+
+ # load the config file with planner settings
+ abs_path = os.path.dirname(os.path.abspath(__file__))
+ config_path = os.path.join(abs_path, 'configs', f'{domain}_slp.cfg')
+ if not os.path.isfile(config_path):
+ raise_warning(f'Config file {config_path} was not found, '
+ f'using default_slp.cfg.', 'red')
+ config_path = os.path.join(abs_path, 'configs', 'default_slp.cfg')
+ planner_args, _, train_args = load_config(config_path)
+
+ # create the planning algorithm
+ planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
+
+ # find the optimal plan
+ loss_fn, grad_fn, guess, unravel_fn = planner.as_optimization_problem()
+ opt = minimize(loss_fn, jac=grad_fn, x0=guess, method=method, options={'disp': True})
+ params = unravel_fn(opt.x)
+
+ # evaluate the optimal plan
+ controller = JaxOfflineController(planner, params=params, **train_args)
+ controller.evaluate(env, episodes=episodes, verbose=True, render=True)
+ env.close()
+
+
+ if __name__ == "__main__":
+ args = sys.argv[1:]
+ if len(args) < 3:
+ print('python run_scipy.py <domain> <instance> <method> [<episodes>]')
+ exit(1)
+ kwargs = {'domain': args[0], 'instance': args[1], 'method': args[2]}
+ if len(args) >= 4: kwargs['episodes'] = int(args[3])
+ main(**kwargs)
+

pyRDDLGym_jax/examples/run_tune.py
@@ -20,6 +20,7 @@ import os
  import sys

  import pyRDDLGym
+ from pyRDDLGym.core.debug.exception import raise_warning

  from pyRDDLGym_jax.core.tuning import (
  JaxParameterTuningDRP, JaxParameterTuningSLP, JaxParameterTuningSLPReplan
@@ -30,11 +31,15 @@ from pyRDDLGym_jax.core.planner import load_config
  def main(domain, instance, method, trials=5, iters=20, workers=4):

  # set up the environment
- env = pyRDDLGym.make(domain, instance, vectorized=True, enforce_action_constraints=True)
+ env = pyRDDLGym.make(domain, instance, vectorized=True)

  # load the config file with planner settings
  abs_path = os.path.dirname(os.path.abspath(__file__))
  config_path = os.path.join(abs_path, 'configs', f'{domain}_{method}.cfg')
+ if not os.path.isfile(config_path):
+ raise_warning(f'Config file {config_path} was not found, '
+ f'using default_{method}.cfg.', 'red')
+ config_path = os.path.join(abs_path, 'configs', f'default_{method}.cfg')
  planner_args, plan_args, train_args = load_config(config_path)

  # define algorithm to perform tuning
@@ -43,8 +48,7 @@ def main(domain, instance, method, trials=5, iters=20, workers=4):
  elif method == 'drp':
  tuning_class = JaxParameterTuningDRP
  elif method == 'replan':
- tuning_class = JaxParameterTuningSLPReplan
-
+ tuning_class = JaxParameterTuningSLPReplan
  tuning = tuning_class(env=env,
  train_epochs=train_args['epochs'],
  timeout_training=train_args['train_seconds'],
@@ -54,6 +58,7 @@ def main(domain, instance, method, trials=5, iters=20, workers=4):
  num_workers=workers,
  gp_iters=iters)

+ # perform tuning and report best parameters
  best = tuning.tune(key=train_args['key'], filename=f'gp_{method}',
  save_plot=True)
  print(f'best parameters found: {best}')

{pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: pyRDDLGym-jax
- Version: 0.1
+ Version: 0.3
  Summary: pyRDDLGym-jax: JAX compilation of RDDL description files, and a differentiable planner in JAX.
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner

pyRDDLGym_jax-0.3.dist-info/RECORD
@@ -0,0 +1,44 @@
+ pyRDDLGym_jax/__init__.py,sha256=Cl7DWkrPP64Ofc2ILXnudFOdnCuKs2p0Pm7ykZOOPh4,19
+ pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ pyRDDLGym_jax/core/compiler.py,sha256=m7p0CHOU4Wma0cKMu_WQwfoieIQ2pXD68hZ8BFJ970A,89103
+ pyRDDLGym_jax/core/logic.py,sha256=zujSHiR5KhTO81E5Zn8Gy_xSzVzfDskFCGvZygFRdMI,21930
+ pyRDDLGym_jax/core/planner.py,sha256=1BtU1G3rihRZaMfNu0VtbSl1LXEXu6pT75EkF6-WVnM,101827
+ pyRDDLGym_jax/core/simulator.py,sha256=fp6bep3XwwBWED0w7_4qhiwDjkSka6B2prwdNcPRCMc,8329
+ pyRDDLGym_jax/core/tuning.py,sha256=Dv0YyOgGnej-zdVymWdkVg0MZjm2lNRfr7gySzFOeow,29589
+ pyRDDLGym_jax/examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ pyRDDLGym_jax/examples/run_gradient.py,sha256=KhXvijRDZ4V7N8NOI2WV8ePGpPna5_vnET61YwS7Tco,2919
+ pyRDDLGym_jax/examples/run_gym.py,sha256=rXvNWkxe4jHllvbvU_EOMji_2-2k5d4tbBKhpMm_Gaw,1526
+ pyRDDLGym_jax/examples/run_plan.py,sha256=OENf8s-SrMlh7CYXNhanQiau35b4atLBJMNjgP88DCg,2463
+ pyRDDLGym_jax/examples/run_scipy.py,sha256=wvcpWCvdjvYHntO95a7JYfY2fuCMUTKnqjJikW0PnL4,2291
+ pyRDDLGym_jax/examples/run_tune.py,sha256=-M4KoBpg5lshQ4mmU0cnLs2i7-ldSIr_OcxHK7YA6bw,3273
+ pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg,sha256=pbkz6ccgk5dHXp7cfYbZNFyJobpGyxUZleCy4fvlmaU,336
+ pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg,sha256=OswO9YD4Xh1pw3R3LkUBb67WLtj5XlE3qnMQ5CKwPsM,332
+ pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg,sha256=FxZ4xcg2j2PzeH-wUseRR280juQN5bJjoyt6PtI1W7c,329
+ pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg,sha256=FTGFwRAGyeRrbDMh_FV8iv8ZHrlj3Htju4pfPNmKIcw,336
+ pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg,sha256=wjtz86_Gz0RfQu3bbrz56PTXL8JMernINx7AtJuZCPs,314
+ pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg,sha256=C_0BFyhGXbtF7N4vyeua2XkORbkj10HELC1GpzM0Uh4,415
+ pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg,sha256=Yb4tFzUOj4epCCsofXAZo70lm5C2KzPIzI5PQHsa_Vk,429
+ pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg,sha256=e7j-1Z66o7F-KZDSf2e8TQRWwkXOPRwrRFkIavK8G7g,327
+ pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg,sha256=Z6CxaOxHv4oF6nW7SfSn_HshlQGDlNCPGASTnDTdL7Q,327
+ pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg,sha256=Uy1mrX-AZMS-KBAhWXJ3c_QAhd4bRSWttDoFGYQ08lQ,315
+ pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg,sha256=SM5_U4RwvvucHVAOdMG4vqH0Eg43f3WX9ZlV6aFPgTw,341
+ pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg,sha256=lcqQ7P7X4qAbMlpkKKuYGn2luSZH-yFB7oi-eHj9Qng,332
+ pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg,sha256=kG1-02ScmwsEwX7QIAZTD7si90Mb06b79G5oqcMQ9Hg,316
+ pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg,sha256=yGMBWiVZT8KdZ1PhQ4kIxPvnjht1ss0UheTV-Nt9oaA,364
+ pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg,sha256=9QNl58PyoJYhmwvrhzUxlLEy8vGbmwE6lRuOdvhLjGQ,317
+ pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg,sha256=rrubYvC1q7Ff0ADV0GXtLw-rD9E4m7qfR66qxdYNTD8,339
+ pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg,sha256=DAb-J2KwvJXViRRSHZe8aJwZiPljC28HtrKJPieeUCY,331
+ pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg,sha256=QwKzCAFaErrTCHaJwDPLOxPHpNGNuAKMUoZjLLnMrNc,314
+ pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg,sha256=QiJCJYOrdXXZfOTuPleGswREFxjGlqQSA0rw00YJWWI,318
+ pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg,sha256=PGkgll7h5vhSF13JScKoQ-vpWaAGNJ_PUEhK7jEjNx4,340
+ pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg,sha256=kEDAwsJQ_t9WPzPhIxfS0hRtgOhtFdJFfmPtTTJuwUE,454
+ pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg,sha256=w2wipsA8PE5OBkYVIKajjtCOtiHqmMeY3XQVPAApwFk,371
+ pyRDDLGym_jax/examples/configs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ pyRDDLGym_jax/examples/configs/default_drp.cfg,sha256=S2-5hPZtgAwUAFpiCAgSi-cnGhYHSDzMGMmatwhbM78,344
+ pyRDDLGym_jax/examples/configs/default_replan.cfg,sha256=VWWPhOYBRq4cWwtrChw5pPqRmlX_nHbMvwciHd9hoLc,357
+ pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=TG3mtHUnCA7J2Gm9SczENpqAymTnzCE9dj1Z_R-FnVk,340
+ pyRDDLGym_jax-0.3.dist-info/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
+ pyRDDLGym_jax-0.3.dist-info/METADATA,sha256=e_1MlMdQoqQHW-KA2OSIZzIAQyfe-jDtMOxkIyhmLmI,1085
+ pyRDDLGym_jax-0.3.dist-info/WHEEL,sha256=y4mX-SOX4fYIkonsAGA5N0Oy-8_gI4FXw5HNI1xqvWg,91
+ pyRDDLGym_jax-0.3.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
+ pyRDDLGym_jax-0.3.dist-info/RECORD,,

{pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/WHEEL
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: bdist_wheel (0.42.0)
+ Generator: setuptools (70.2.0)
  Root-Is-Purelib: true
  Tag: py3-none-any

pyRDDLGym_jax/examples/configs/SupplyChain_slp.cfg
@@ -1,18 +0,0 @@
- [Model]
- logic='FuzzyLogic'
- logic_kwargs={'weight': 10.0}
- tnorm='ProductTNorm'
- tnorm_kwargs={}
-
- [Optimizer]
- method='JaxStraightLinePlan'
- method_kwargs={}
- optimizer='rmsprop'
- optimizer_kwargs={'learning_rate': 0.005}
- batch_size_train=8
- batch_size_test=8
-
- [Training]
- key=42
- epochs=10000
- train_seconds=90

pyRDDLGym_jax/examples/configs/Traffic_slp.cfg
@@ -1,20 +0,0 @@
- [Model]
- logic='FuzzyLogic'
- logic_kwargs={'weight': 1000}
- tnorm='ProductTNorm'
- tnorm_kwargs={}
-
- [Optimizer]
- method='JaxStraightLinePlan'
- method_kwargs={}
- optimizer='rmsprop'
- optimizer_kwargs={'learning_rate': 0.001}
- batch_size_train=16
- batch_size_test=16
- clip_grad=1.0
-
- [Training]
- key=42
- epochs=200
- train_seconds=30
- policy_hyperparams={'advance': 10.0}