pyRDDLGym-jax 2.3 → 2.5 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- pyRDDLGym_jax/core/simulator.py (2.3)
+++ pyRDDLGym_jax/core/simulator.py (2.5)
@@ -19,10 +19,12 @@
 
 
  import time
- from typing import Dict, Optional
+ import numpy as np
+ from typing import Dict, Optional, Union
 
  import jax
 
+ from pyRDDLGym.core.compiler.initializer import RDDLValueInitializer
  from pyRDDLGym.core.compiler.model import RDDLLiftedModel
  from pyRDDLGym.core.debug.exception import (
      RDDLActionPreconditionNotSatisfiedError,
@@ -35,7 +37,7 @@ from pyRDDLGym.core.simulator import RDDLSimulator
 
  from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler
 
- Args = Dict[str, Value]
+ Args = Dict[str, Union[np.ndarray, Value]]
 
 
  class JaxRDDLSimulator(RDDLSimulator):
@@ -45,6 +47,7 @@ class JaxRDDLSimulator(RDDLSimulator):
                   raise_error: bool=True,
                   logger: Optional[Logger]=None,
                   keep_tensors: bool=False,
+                  objects_as_strings: bool=True,
                   **compiler_args) -> None:
          '''Creates a new simulator for the given RDDL model with Jax as a backend.
 
@@ -57,6 +60,8 @@ class JaxRDDLSimulator(RDDLSimulator):
          :param logger: to log information about compilation to file
          :param keep_tensors: whether the sampler takes actions and
              returns state in numpy array form
+         :param objects_as_strings: whether to return object values as strings (defaults
+             to integer indices if False)
          :param **compiler_args: keyword arguments to pass to the Jax compiler
          '''
          if key is None:
@@ -67,7 +72,8 @@ class JaxRDDLSimulator(RDDLSimulator):
 
          # generate direct sampling with default numpy RNG and operations
          super(JaxRDDLSimulator, self).__init__(
-             rddl, logger=logger, keep_tensors=keep_tensors)
+             rddl, logger=logger,
+             keep_tensors=keep_tensors, objects_as_strings=objects_as_strings)
 
      def seed(self, seed: int) -> None:
          super(JaxRDDLSimulator, self).seed(seed)
@@ -84,11 +90,11 @@ class JaxRDDLSimulator(RDDLSimulator):
          self.levels = compiled.levels
          self.traced = compiled.traced
 
-         self.invariants = jax.tree_map(jax.jit, compiled.invariants)
-         self.preconds = jax.tree_map(jax.jit, compiled.preconditions)
-         self.terminals = jax.tree_map(jax.jit, compiled.terminations)
+         self.invariants = jax.tree_util.tree_map(jax.jit, compiled.invariants)
+         self.preconds = jax.tree_util.tree_map(jax.jit, compiled.preconditions)
+         self.terminals = jax.tree_util.tree_map(jax.jit, compiled.terminations)
          self.reward = jax.jit(compiled.reward)
-         jax_cpfs = jax.tree_map(jax.jit, compiled.cpfs)
+         jax_cpfs = jax.tree_util.tree_map(jax.jit, compiled.cpfs)
          self.model_params = compiled.model_params
 
          # level analysis
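The hunk above tracks an upstream JAX API migration: `jax.tree_map` was deprecated and eventually removed in favor of `jax.tree_util.tree_map`, which takes the same arguments. A minimal sketch of the pattern being used here (jit-compiling every function stored in a pytree); the dictionary below is illustrative, not from the package:

```python
import jax

# A dict of Python callables is a pytree whose leaves are the callables;
# tree_map applies jax.jit to each leaf, which is how the simulator wraps
# its compiled invariants, preconditions, terminations, and CPFs.
fns = {'square': lambda x: x * x, 'double': lambda x: 2.0 * x}
jitted = jax.tree_util.tree_map(jax.jit, fns)

print(jitted['square'](3.0))  # 9.0
print(jitted['double'](3.0))  # 6.0
```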
@@ -139,7 +145,6 @@ class JaxRDDLSimulator(RDDLSimulator):
 
      def check_action_preconditions(self, actions: Args, silent: bool=False) -> bool:
          '''Throws an exception if the action preconditions are not satisfied.'''
-         actions = self._process_actions(actions)
          subs = self.subs
          subs.update(actions)
 
@@ -180,7 +185,6 @@
          '''
          rddl = self.rddl
          keep_tensors = self.keep_tensors
-         actions = self._process_actions(actions)
          subs = self.subs
          subs.update(actions)
 
@@ -196,20 +200,40 @@
          # update state
          self.state = {}
          for (state, next_state) in rddl.next_state.items():
+
+             # set state = state' for the next epoch
              subs[state] = subs[next_state]
+
+             # convert object integer to string representation
+             state_values = subs[state]
+             if self.objects_as_strings:
+                 ptype = rddl.variable_ranges[state]
+                 if ptype not in RDDLValueInitializer.NUMPY_TYPES:
+                     state_values = rddl.index_to_object_string_array(ptype, state_values)
+
+             # optional grounding of state dictionary
              if keep_tensors:
-                 self.state[state] = subs[state]
+                 self.state[state] = state_values
              else:
-                 self.state.update(rddl.ground_var_with_values(state, subs[state]))
+                 self.state.update(rddl.ground_var_with_values(state, state_values))
 
          # update observation
          if self._pomdp:
              obs = {}
              for var in rddl.observ_fluents:
+
+                 # convert object integer to string representation
+                 obs_values = subs[var]
+                 if self.objects_as_strings:
+                     ptype = rddl.variable_ranges[var]
+                     if ptype not in RDDLValueInitializer.NUMPY_TYPES:
+                         obs_values = rddl.index_to_object_string_array(ptype, obs_values)
+
+                 # optional grounding of observ-fluent dictionary
                  if keep_tensors:
-                     obs[var] = subs[var]
+                     obs[var] = obs_values
                  else:
-                     obs.update(rddl.ground_var_with_values(var, subs[var]))
+                     obs.update(rddl.ground_var_with_values(var, obs_values))
          else:
              obs = self.state
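For callers, a hedged sketch of what the new `objects_as_strings` flag changes (`model` below is a stand-in for any compiled `RDDLLiftedModel`; the constructor follows the hunks above):

```python
# Hypothetical usage, assuming a RDDLLiftedModel instance named `model`.
# With objects_as_strings=True (the new default), object-valued state and
# observation fluents are converted from integer indices to object-name
# strings via rddl.index_to_object_string_array before being returned;
# with False, callers keep the raw integer indices.
from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator

sim_strings = JaxRDDLSimulator(model, objects_as_strings=True)
sim_indices = JaxRDDLSimulator(model, objects_as_strings=False)
```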
--- pyRDDLGym_jax/core/tuning.py (2.3)
+++ pyRDDLGym_jax/core/tuning.py (2.5)
@@ -18,6 +18,7 @@ import datetime
  import threading
  import multiprocessing
  import os
+ import termcolor
  import time
  import traceback
  from typing import Any, Callable, Dict, Iterable, Optional, Tuple
@@ -45,8 +46,7 @@ try:
      from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard
  except Exception:
      raise_warning('Failed to load the dashboard visualization tool: '
-                   'please make sure you have installed the required packages.',
-                   'red')
+                   'please make sure you have installed the required packages.', 'red')
      traceback.print_exc()
      JaxPlannerDashboard = None
 
@@ -159,24 +159,24 @@ class JaxParameterTuning:
          kernel3 = Matern(length_scale=5.0, length_scale_bounds=(1.0, 5.0), nu=2.5)
          return weight1 * kernel1 + weight2 * kernel2 + weight3 * kernel3
 
-     def summarize_hyperparameters(self) -> None:
+     def summarize_hyperparameters(self) -> str:
          hyper_params_table = []
          for (_, param) in self.hyperparams_dict.items():
              hyper_params_table.append(f' {str(param)}')
          hyper_params_table = '\n'.join(hyper_params_table)
-         print(f'hyperparameter optimizer parameters:\n'
-               f' tuned_hyper_parameters =\n{hyper_params_table}\n'
-               f' initialization_args ={self.gp_init_kwargs}\n'
-               f' gp_params ={self.gp_params}\n'
-               f' tuning_iterations ={self.gp_iters}\n'
-               f' tuning_timeout ={self.timeout_tuning}\n'
-               f' tuning_batch_size ={self.num_workers}\n'
-               f' mp_pool_context_type ={self.pool_context}\n'
-               f' mp_pool_poll_frequency ={self.poll_frequency}\n'
-               f'meta-objective parameters:\n'
-               f' planning_trials_per_iter ={self.eval_trials}\n'
-               f' rollouts_per_trial ={self.rollouts_per_trial}\n'
-               f' acquisition_fn ={self.acquisition}')
+         return (f'hyperparameter optimizer parameters:\n'
+                 f' tuned_hyper_parameters =\n{hyper_params_table}\n'
+                 f' initialization_args ={self.gp_init_kwargs}\n'
+                 f' gp_params ={self.gp_params}\n'
+                 f' tuning_iterations ={self.gp_iters}\n'
+                 f' tuning_timeout ={self.timeout_tuning}\n'
+                 f' tuning_batch_size ={self.num_workers}\n'
+                 f' mp_pool_context_type ={self.pool_context}\n'
+                 f' mp_pool_poll_frequency ={self.poll_frequency}\n'
+                 f'meta-objective parameters:\n'
+                 f' planning_trials_per_iter ={self.eval_trials}\n'
+                 f' rollouts_per_trial ={self.rollouts_per_trial}\n'
+                 f' acquisition_fn ={self.acquisition}')
 
      @staticmethod
      def annealing_acquisition(n_samples: int, n_delay_samples: int=0,
@@ -346,6 +346,7 @@ class JaxParameterTuning:
 
          # remove keywords that should not be in the tuner
          train_args.pop('dashboard', None)
+         planner_args.pop('parallel_updates', None)
 
          # initialize env for evaluation (need fresh copy to avoid concurrency)
          env = RDDLEnv(domain, instance, vectorized=True, enforce_action_constraints=False)
@@ -368,18 +369,32 @@
 
      def tune_optimizer(self, optimizer: BayesianOptimization) -> None:
          '''Tunes the Bayesian optimization algorithm hyper-parameters.'''
-         print('\n' + f'The current kernel is {repr(optimizer._gp.kernel_)}.')
+         print(f'Kernel: {repr(optimizer._gp.kernel_)}.')
 
-     def tune(self, key: int, log_file: str, show_dashboard: bool=False) -> ParameterValues:
-         '''Tunes the hyper-parameters for Jax planner, returns the best found.'''
+     def tune(self, key: int,
+              log_file: Optional[str]=None,
+              show_dashboard: bool=False,
+              print_hyperparams: bool=False) -> ParameterValues:
+         '''Tunes the hyper-parameters for Jax planner, returns the best found.
 
-         self.summarize_hyperparameters()
+         :param key: RNG key to seed the hyper-parameter optimizer
+         :param log_file: optional path to file where tuning progress will be saved
+         :param show_dashboard: whether to display tuning results in a dashboard
+         :param print_hyperparams: whether to print a hyper-parameter summary of the
+             optimizer
+         '''
 
-         # clear and prepare output file
-         with open(log_file, 'w', newline='') as file:
-             writer = csv.writer(file)
-             writer.writerow(COLUMNS + list(self.hyperparams_dict.keys()))
+         if self.verbose:
+             print(JaxBackpropPlanner.summarize_system())
+         if print_hyperparams:
+             print(self.summarize_hyperparameters())
 
+         # clear and prepare output file
+         if log_file is not None:
+             with open(log_file, 'w', newline='') as file:
+                 writer = csv.writer(file)
+                 writer.writerow(COLUMNS + list(self.hyperparams_dict.keys()))
+
          # create a dash-board for visualizing experiment runs
          if show_dashboard and JaxPlannerDashboard is not None:
              dashboard = JaxPlannerDashboard()
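Taken together, a sketch of how the new `tune()` signature is called (assuming `tuning` is a `JaxParameterTuning` instance built as in the README example near the end of this diff):

```python
# log_file is now optional: None skips CSV logging entirely, and the
# hyper-parameter summary (now returned as a string by
# summarize_hyperparameters) is only printed when explicitly requested.
best_params = tuning.tune(key=42,
                          log_file=None,
                          show_dashboard=False,
                          print_hyperparams=True)
```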
@@ -445,13 +460,15 @@
              # check if there is enough time left for another iteration
              elapsed = time.time() - start_time
              if elapsed >= self.timeout_tuning:
-                 print(f'global time limit reached at iteration {it}, aborting')
+                 message = termcolor.colored(
+                     f'[INFO] Global time limit reached at iteration {it}.', 'green')
+                 print(message)
                  break
 
              # continue with next iteration
              print('\n' + '*' * 80 +
                    f'\n[{datetime.timedelta(seconds=elapsed)}] ' +
-                   f'starting iteration {it + 1}' +
+                   f'Starting iteration {it + 1}' +
                    '\n' + '*' * 80)
              key, *subkeys = jax.random.split(key, num=num_workers + 1)
              rows = [None] * num_workers
@@ -507,15 +524,19 @@
 
              # print best parameter if found
              if best_target > old_best_target:
-                 print(f'* found new best average reward {best_target:.6f}')
+                 message = termcolor.colored(
+                     f'[INFO] Found new best average reward {best_target:.6f}.',
+                     'green')
+                 print(message)
 
              # tune the optimizer here
              self.tune_optimizer(optimizer)
 
              # write results of all processes in current iteration to file
-             with open(log_file, 'a', newline='') as file:
-                 writer = csv.writer(file)
-                 writer.writerows(rows)
+             if log_file is not None:
+                 with open(log_file, 'a', newline='') as file:
+                     writer = csv.writer(file)
+                     writer.writerows(rows)
 
              # update the dashboard tuning
              if show_dashboard:
@@ -528,7 +549,7 @@
 
          # print summary of results
          elapsed = time.time() - start_time
-         print(f'summary of hyper-parameter optimization:\n'
+         print(f'Summary of hyper-parameter optimization:\n'
                f' time_elapsed ={datetime.timedelta(seconds=elapsed)}\n'
                f' iterations ={it + 1}\n'
                f' best_hyper_parameters={best_params}\n'
--- pyRDDLGym_jax/entry_point.py (2.3)
+++ pyRDDLGym_jax/entry_point.py (2.5)
@@ -2,24 +2,56 @@ import argparse
 
  from pyRDDLGym_jax.examples import run_plan, run_tune
 
+ EPILOG = 'For complete documentation, see https://pyrddlgym.readthedocs.io/en/latest/jax.html.'
+
  def main():
-     parser = argparse.ArgumentParser(description="Command line parser for the JaxPlan planner.")
+     parser = argparse.ArgumentParser(prog='jaxplan',
+                                      description="command line parser for the jaxplan planner",
+                                      epilog=EPILOG)
      subparsers = parser.add_subparsers(dest="jaxplan", required=True)
 
      # planning
-     parser_plan = subparsers.add_parser("plan", help="Executes JaxPlan on a specified RDDL problem and method (slp, drp, or replan).")
-     parser_plan.add_argument('args', nargs=argparse.REMAINDER)
+     parser_plan = subparsers.add_parser("plan",
+                                         help="execute jaxplan on a specified RDDL problem",
+                                         epilog=EPILOG)
+     parser_plan.add_argument('domain', type=str,
+                              help='name of domain in rddlrepository or a valid file path')
+     parser_plan.add_argument('instance', type=str,
+                              help='name of instance in rddlrepository or a valid file path')
+     parser_plan.add_argument('method', type=str,
+                              help='training method to apply: [slp, drp] are offline methods, and [replan] are online')
+     parser_plan.add_argument('-e', '--episodes', type=int, required=False, default=1,
+                              help='number of training or evaluation episodes')
 
      # tuning
-     parser_tune = subparsers.add_parser("tune", help="Tunes JaxPlan on a specified RDDL problem and method (slp, drp, or replan).")
-     parser_tune.add_argument('args', nargs=argparse.REMAINDER)
+     parser_tune = subparsers.add_parser("tune",
+                                         help="tune jaxplan on a specified RDDL problem",
+                                         epilog=EPILOG)
+     parser_tune.add_argument('domain', type=str,
+                              help='name of domain in rddlrepository or a valid file path')
+     parser_tune.add_argument('instance', type=str,
+                              help='name of instance in rddlrepository or a valid file path')
+     parser_tune.add_argument('method', type=str,
+                              help='training method to apply: [slp, drp] are offline methods, and [replan] are online')
+     parser_tune.add_argument('-t', '--trials', type=int, required=False, default=5,
+                              help='number of evaluation rollouts per hyper-parameter choice')
+     parser_tune.add_argument('-i', '--iters', type=int, required=False, default=20,
+                              help='number of iterations of bayesian optimization')
+     parser_tune.add_argument('-w', '--workers', type=int, required=False, default=4,
+                              help='number of parallel hyper-parameters to evaluate per iteration')
+     parser_tune.add_argument('-d', '--dashboard', type=bool, required=False, default=False,
+                              help='show the dashboard')
+     parser_tune.add_argument('-f', '--filepath', type=str, required=False, default='',
+                              help='where to save the config file of the best hyper-parameters')
 
      # dispatch
      args = parser.parse_args()
      if args.jaxplan == "plan":
-         run_plan.run_from_args(args.args)
+         run_plan.main(args.domain, args.instance, args.method, args.episodes)
      elif args.jaxplan == "tune":
-         run_tune.run_from_args(args.args)
+         run_tune.main(args.domain, args.instance, args.method,
+                       args.trials, args.iters, args.workers, args.dashboard,
+                       args.filepath)
      else:
          parser.print_help()
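One caveat worth flagging in the new `tune` subcommand: `argparse` with `type=bool` converts the raw string with `bool()`, so any non-empty value, including `--dashboard False`, parses as `True`. This is standard Python behavior rather than anything specific to this package; a `store_true` flag is the usual alternative:

```python
import argparse

# bool('False') is True in Python, so type=bool cannot express an off switch;
# action='store_true' yields a proper on/off boolean flag instead.
parser = argparse.ArgumentParser()
parser.add_argument('--dashboard', action='store_true',
                    help='show the dashboard')

print(parser.parse_args([]).dashboard)               # False
print(parser.parse_args(['--dashboard']).dashboard)  # True
```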
--- pyRDDLGym_jax/examples/configs/tuning_drp.cfg (2.3)
+++ pyRDDLGym_jax/examples/configs/tuning_drp.cfg (2.5)
@@ -11,6 +11,7 @@ optimizer='rmsprop'
  optimizer_kwargs={'learning_rate': LEARNING_RATE_TUNE}
  batch_size_train=32
  batch_size_test=32
+ print_warnings=False
 
  [Training]
  train_seconds=30
--- pyRDDLGym_jax/examples/configs/tuning_replan.cfg (2.3)
+++ pyRDDLGym_jax/examples/configs/tuning_replan.cfg (2.5)
@@ -12,6 +12,7 @@ optimizer_kwargs={'learning_rate': LEARNING_RATE_TUNE}
  batch_size_train=32
  batch_size_test=32
  rollout_horizon=ROLLOUT_HORIZON_TUNE
+ print_warnings=False
 
  [Training]
  train_seconds=1
--- pyRDDLGym_jax/examples/configs/tuning_slp.cfg (2.3)
+++ pyRDDLGym_jax/examples/configs/tuning_slp.cfg (2.5)
@@ -11,6 +11,7 @@ optimizer='rmsprop'
  optimizer_kwargs={'learning_rate': LEARNING_RATE_TUNE}
  batch_size_train=32
  batch_size_test=32
+ print_warnings=False
 
  [Training]
  train_seconds=30
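For orientation, the tuning configs above are templates: uppercase placeholders such as LEARNING_RATE_TUNE are substituted with concrete values before the config is parsed, as the README section later in this diff describes. A hedged, standalone sketch of that substitution idea, using only keys visible in these hunks:

```python
# Illustrative only: replace placeholder patterns in a config template with
# concrete candidate values, the way the tuner instantiates each trial.
template = """optimizer_kwargs={'learning_rate': LEARNING_RATE_TUNE}
batch_size_train=32
print_warnings=False

[Training]
train_seconds=30
"""
config = template.replace('LEARNING_RATE_TUNE', str(10.0 ** -3))
print(config)
```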
--- pyRDDLGym_jax/examples/run_plan.py (2.3)
+++ pyRDDLGym_jax/examples/run_plan.py (2.5)
@@ -26,7 +26,7 @@ from pyRDDLGym_jax.core.planner import (
  )
 
 
- def main(domain, instance, method, episodes=1):
+ def main(domain: str, instance: str, method: str, episodes: int=1) -> None:
 
      # set up the environment
      env = pyRDDLGym.make(domain, instance, vectorized=True)
@@ -36,8 +36,8 @@ def main(domain, instance, method, episodes=1):
      abs_path = os.path.dirname(os.path.abspath(__file__))
      config_path = os.path.join(abs_path, 'configs', f'{domain}_{method}.cfg')
      if not os.path.isfile(config_path):
-         raise_warning(f'Config file {config_path} was not found, '
-                       f'using default_{method}.cfg.', 'red')
+         raise_warning(f'[WARN] Config file {config_path} was not found, '
+                       f'using default_{method}.cfg.', 'yellow')
          config_path = os.path.join(abs_path, 'configs', f'default_{method}.cfg')
      elif os.path.isfile(method):
          config_path = method
--- pyRDDLGym_jax/examples/run_scipy.py (2.3)
+++ pyRDDLGym_jax/examples/run_scipy.py (2.5)
@@ -31,8 +31,8 @@ def main(domain, instance, method, episodes=1):
      abs_path = os.path.dirname(os.path.abspath(__file__))
      config_path = os.path.join(abs_path, 'configs', f'{domain}_slp.cfg')
      if not os.path.isfile(config_path):
-         raise_warning(f'Config file {config_path} was not found, '
-                       f'using default_slp.cfg.', 'red')
+         raise_warning(f'[WARN] Config file {config_path} was not found, '
+                       f'using default_slp.cfg.', 'yellow')
          config_path = os.path.join(abs_path, 'configs', 'default_slp.cfg')
      planner_args, _, train_args = load_config(config_path)
 
--- pyRDDLGym_jax/examples/run_tune.py (2.3)
+++ pyRDDLGym_jax/examples/run_tune.py (2.5)
@@ -36,7 +36,9 @@ def power_10(x):
      return 10.0 ** x
 
 
- def main(domain, instance, method, trials=5, iters=20, workers=4, dashboard=False):
+ def main(domain: str, instance: str, method: str,
+          trials: int=5, iters: int=20, workers: int=4, dashboard: bool=False,
+          filepath: str='') -> None:
 
      # set up the environment
      env = pyRDDLGym.make(domain, instance, vectorized=True)
@@ -68,6 +70,9 @@ def main(domain, instance, method, trials=5, iters=20, workers=4, dashboard=False):
      tuning.tune(key=42,
                  log_file=f'gp_{method}_{domain}_{instance}.csv',
                  show_dashboard=dashboard)
+     if filepath is not None and filepath:
+         with open(filepath, "w") as file:
+             file.write(tuning.best_config)
 
      # evaluate the agent on the best parameters
      planner_args, _, train_args = load_config_from_string(tuning.best_config)
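The new `filepath` branch writes `tuning.best_config` (a config-file string) to disk so the best setting can be reused without re-tuning. A hedged sketch of the round trip; `load_config` is the config loader imported elsewhere in these examples, and the file name is illustrative:

```python
from pyRDDLGym_jax.core.planner import load_config

# persist the best configuration found by the tuner, as the new code above does
with open('best_config.cfg', 'w') as file:
    file.write(tuning.best_config)

# later: rebuild planner and training arguments from the saved file
planner_args, _, train_args = load_config('best_config.cfg')
```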
@@ -80,7 +85,7 @@
 
  def run_from_args(args):
      if len(args) < 3:
-         print('python run_tune.py <domain> <instance> <method> [<trials>] [<iters>] [<workers>] [<dashboard>]')
+         print('python run_tune.py <domain> <instance> <method> [<trials>] [<iters>] [<workers>] [<dashboard>] [<filepath>]')
          exit(1)
      if args[2] not in ['drp', 'slp', 'replan']:
          print('<method> in [drp, slp, replan]')
@@ -90,6 +95,7 @@ def run_from_args(args):
      if len(args) >= 5: kwargs['iters'] = int(args[4])
      if len(args) >= 6: kwargs['workers'] = int(args[5])
      if len(args) >= 7: kwargs['dashboard'] = bool(args[6])
+     if len(args) >= 8: kwargs['filepath'] = str(args[7])
      main(**kwargs)
 
 
--- pyrddlgym_jax-2.3.dist-info/METADATA
+++ pyrddlgym_jax-2.5.dist-info/METADATA
@@ -1,6 +1,6 @@
- Metadata-Version: 2.2
+ Metadata-Version: 2.4
  Name: pyRDDLGym-jax
- Version: 2.3
+ Version: 2.5
  Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -39,6 +39,7 @@ Dynamic: description
  Dynamic: description-content-type
  Dynamic: home-page
  Dynamic: license
+ Dynamic: license-file
  Dynamic: provides-extra
  Dynamic: requires-dist
  Dynamic: requires-python
@@ -116,7 +117,7 @@ pip install pyRDDLGym-jax[extra,dashboard]
  A basic run script is provided to train JaxPlan on any RDDL problem:
 
  ```shell
- jaxplan plan <domain> <instance> <method> <episodes>
+ jaxplan plan <domain> <instance> <method> --episodes <episodes>
  ```
 
  where:
@@ -241,7 +242,7 @@ More documentation about this and other new features will be coming soon.
  A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:
 
  ```shell
- jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
+ jaxplan tune <domain> <instance> <method> --trials <trials> --iters <iters> --workers <workers> --dashboard <dashboard> --filepath <filepath>
  ```
 
  where:
@@ -251,7 +252,8 @@ where:
  - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
  - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
  - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``
- - ``dashboard`` is whether the optimizations are tracked in the dashboard application.
+ - ``dashboard`` is whether the optimizations are tracked in the dashboard application
+ - ``filepath`` is the optional file path where a config file with the best hyper-parameter setting will be saved.
 
  It is easy to tune a custom range of the planner's hyper-parameters efficiently.
  First create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
@@ -291,23 +293,16 @@ env = pyRDDLGym.make(domain, instance, vectorized=True)
  with open('path/to/config.cfg', 'r') as file:
      config_template = file.read()
 
- # map parameters in the config that will be tuned
+ # tune weight from 10^-1 ... 10^5 and lr from 10^-5 ... 10^1
  def power_10(x):
-     return 10.0 ** x
-
- hyperparams = [
-     Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10), # tune weight from 10^-1 ... 10^5
-     Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10), # tune lr from 10^-5 ... 10^1
- ]
+     return 10.0 ** x
+ hyperparams = [Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),
+                Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10)]
 
  # build the tuner and tune
  tuning = JaxParameterTuning(env=env,
-                             config_template=config_template,
-                             hyperparams=hyperparams,
-                             online=False,
-                             eval_trials=trials,
-                             num_workers=workers,
-                             gp_iters=iters)
+                             config_template=config_template, hyperparams=hyperparams,
+                             online=False, eval_trials=trials, num_workers=workers, gp_iters=iters)
  tuning.tune(key=42, log_file='path/to/log.csv')
  ```
 
--- pyrddlgym_jax-2.3.dist-info/RECORD
+++ pyrddlgym_jax-2.5.dist-info/RECORD
@@ -1,20 +1,20 @@
- pyRDDLGym_jax/__init__.py,sha256=ab_pLSTaKv50-5b6lazl75TqhQi0bNsErQ8JlBepVII,19
- pyRDDLGym_jax/entry_point.py,sha256=dxDlO_5gneEEViwkLCg30Z-KVzUgdRXaKuFjoZklkA0,974
+ pyRDDLGym_jax/__init__.py,sha256=VoxLo_sy8RlJIIyu7szqL-cdMGBJdQPg-aSeyOVVIkY,19
+ pyRDDLGym_jax/entry_point.py,sha256=K0zy1oe66jfBHkHHCM6aGHbbiVqnQvDhDb8se4uaKHE,3319
  pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- pyRDDLGym_jax/core/compiler.py,sha256=fLOdJED-Cxtm_IT4LRiZ461Alp9Qjr0vBsOnw1s__EY,82612
- pyRDDLGym_jax/core/logic.py,sha256=0NNm0OaeKv46K0VNY6vL0PHOUFZPNxqQLOvQYkHCswM,56093
- pyRDDLGym_jax/core/planner.py,sha256=0rluBXKGNHRPEPfegOWcx9__cJHr8KjZdDJtG7i1JjI,122793
- pyRDDLGym_jax/core/simulator.py,sha256=DnPL93WVCMZqtqMUoiJdfWcH9pEvNgGfDfO4NV0wIS0,9271
- pyRDDLGym_jax/core/tuning.py,sha256=RKKtDZp7unvfbhZEoaunZtcAn5xtzGYqXBB_Ij_Aapc,24205
+ pyRDDLGym_jax/core/compiler.py,sha256=uFCtoipsIa3MM9nGgT3X8iCViPl2XSPNXh0jMdzN0ko,82895
+ pyRDDLGym_jax/core/logic.py,sha256=lfc2ak_ap_ajMEFlB5EHCRNgJym31dNyA-5d-7N4CZA,56271
+ pyRDDLGym_jax/core/planner.py,sha256=M6GKzN7Ml57B4ZrFZhhkpsQCvReKaCQNzer7zeHCM9E,140275
+ pyRDDLGym_jax/core/simulator.py,sha256=ayCATTUL3clLaZPQ5OUg2bI_c26KKCTq6TbrxbMsVdc,10470
+ pyRDDLGym_jax/core/tuning.py,sha256=BWcQZk02TMLexTz1Sw4lX2EQKvmPbp7biC51M-IiNUw,25153
  pyRDDLGym_jax/core/visualization.py,sha256=4BghMp8N7qtF0tdyDSqtxAxNfP9HPrQWTiXzAMJmx7o,70365
  pyRDDLGym_jax/core/assets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  pyRDDLGym_jax/core/assets/favicon.ico,sha256=RMMrI9YvmF81TgYG7FO7UAre6WmYFkV3B2GmbA1l0kM,175085
  pyRDDLGym_jax/examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  pyRDDLGym_jax/examples/run_gradient.py,sha256=KhXvijRDZ4V7N8NOI2WV8ePGpPna5_vnET61YwS7Tco,2919
  pyRDDLGym_jax/examples/run_gym.py,sha256=rXvNWkxe4jHllvbvU_EOMji_2-2k5d4tbBKhpMm_Gaw,1526
- pyRDDLGym_jax/examples/run_plan.py,sha256=v2AvwgIa4Ejr626vBOgWFJIQvay3IPKWno02ztIFCYc,2768
- pyRDDLGym_jax/examples/run_scipy.py,sha256=wvcpWCvdjvYHntO95a7JYfY2fuCMUTKnqjJikW0PnL4,2291
- pyRDDLGym_jax/examples/run_tune.py,sha256=WbGO8RudIK-cPMAMKvI8NbFQAqkG-Blbnta3Efsep6c,3828
+ pyRDDLGym_jax/examples/run_plan.py,sha256=4y7JHqTxY5O1ltP6N7rar0jMiw7u9w1nuAIOcmDaAuE,2806
+ pyRDDLGym_jax/examples/run_scipy.py,sha256=7uVnDXb7D3NTJqA2L8nrcYDJP-k0ba9dl9YqA2CD9ac,2301
+ pyRDDLGym_jax/examples/run_tune.py,sha256=F5KWgtoCPbf7XHB6HW9LjxarD57U2LvuGdTz67OL1DY,4114
  pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg,sha256=mE8MqhOlkHeXIGEVrnR3QY6I-_iy4uxFYRA71P1bmtk,347
  pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg,sha256=nFFYHCKQUMn8x-OpJwu2pwe1tycNSJ8iAIwSkCBn33E,370
  pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg,sha256=eJ3HvHjODoKdtX7u-AM51xQaHJnYgzEy2t3omNG2oCs,340
@@ -38,12 +38,12 @@ pyRDDLGym_jax/examples/configs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5
  pyRDDLGym_jax/examples/configs/default_drp.cfg,sha256=XeMWAAG_OFZo7JAMxS5-XXroZaeVMzfM0NswmEobIns,373
  pyRDDLGym_jax/examples/configs/default_replan.cfg,sha256=CK4cEz8ReXyAZPLaLG9clIIRXAqM3IplUCxbLt_V2lY,407
  pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=mJo0woDevhQCSQfJg30ULVy9qGIJDIw73XCe6pyIPtg,369
- pyRDDLGym_jax/examples/configs/tuning_drp.cfg,sha256=CQMpSCKTkGioO7U82mHMsYWFRsutULx0V6Wrl3YzV2U,504
- pyRDDLGym_jax/examples/configs/tuning_replan.cfg,sha256=m_0nozFg_GVld0tGv92Xao_KONFJDq_vtiJKt5isqI8,501
- pyRDDLGym_jax/examples/configs/tuning_slp.cfg,sha256=KHu8II6CA-h_HblwvWHylNRjSvvGS3VHxN7JQNR4p_Q,464
- pyrddlgym_jax-2.3.dist-info/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
- pyrddlgym_jax-2.3.dist-info/METADATA,sha256=MS6tckyg-bAQBGZJ112VQPZm5at660EfhntCnfrlUbE,17021
- pyrddlgym_jax-2.3.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
- pyrddlgym_jax-2.3.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
- pyrddlgym_jax-2.3.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
- pyrddlgym_jax-2.3.dist-info/RECORD,,
+ pyRDDLGym_jax/examples/configs/tuning_drp.cfg,sha256=zocZn_cVarH5i0hOlt2Zu0NwmXYBmTTghLaXLtQOGto,526
+ pyRDDLGym_jax/examples/configs/tuning_replan.cfg,sha256=9oIhtw9cuikmlbDgCgbrTc5G7hUio-HeAv_3CEGVclY,523
+ pyRDDLGym_jax/examples/configs/tuning_slp.cfg,sha256=QqnyR__5-HhKeCDfGDel8VIlqsjxRHk4SSH089zJP8s,486
+ pyrddlgym_jax-2.5.dist-info/licenses/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
+ pyrddlgym_jax-2.5.dist-info/METADATA,sha256=XAaEJfbsYW-txxZhFZ6o_HmvqxkIMTqBF9LbV-KdTzI,17058
+ pyrddlgym_jax-2.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ pyrddlgym_jax-2.5.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
+ pyrddlgym_jax-2.5.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
+ pyrddlgym_jax-2.5.dist-info/RECORD,,
--- pyrddlgym_jax-2.3.dist-info/WHEEL
+++ pyrddlgym_jax-2.5.dist-info/WHEEL
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (76.0.0)
+ Generator: setuptools (80.9.0)
  Root-Is-Purelib: true
  Tag: py3-none-any