pyRDDLGym-jax 2.3__py3-none-any.whl → 2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +10 -7
- pyRDDLGym_jax/core/logic.py +117 -66
- pyRDDLGym_jax/core/planner.py +585 -248
- pyRDDLGym_jax/core/simulator.py +37 -13
- pyRDDLGym_jax/core/tuning.py +52 -31
- pyRDDLGym_jax/entry_point.py +39 -7
- pyRDDLGym_jax/examples/configs/tuning_drp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/tuning_replan.cfg +1 -0
- pyRDDLGym_jax/examples/configs/tuning_slp.cfg +1 -0
- pyRDDLGym_jax/examples/run_plan.py +3 -3
- pyRDDLGym_jax/examples/run_scipy.py +2 -2
- pyRDDLGym_jax/examples/run_tune.py +8 -2
- {pyrddlgym_jax-2.3.dist-info → pyrddlgym_jax-2.5.dist-info}/METADATA +13 -18
- {pyrddlgym_jax-2.3.dist-info → pyrddlgym_jax-2.5.dist-info}/RECORD +19 -19
- {pyrddlgym_jax-2.3.dist-info → pyrddlgym_jax-2.5.dist-info}/WHEEL +1 -1
- {pyrddlgym_jax-2.3.dist-info → pyrddlgym_jax-2.5.dist-info}/entry_points.txt +0 -0
- {pyrddlgym_jax-2.3.dist-info → pyrddlgym_jax-2.5.dist-info/licenses}/LICENSE +0 -0
- {pyrddlgym_jax-2.3.dist-info → pyrddlgym_jax-2.5.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/simulator.py
CHANGED
|
@@ -19,10 +19,12 @@
|
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
import time
|
|
22
|
-
|
|
22
|
+
import numpy as np
|
|
23
|
+
from typing import Dict, Optional, Union
|
|
23
24
|
|
|
24
25
|
import jax
|
|
25
26
|
|
|
27
|
+
from pyRDDLGym.core.compiler.initializer import RDDLValueInitializer
|
|
26
28
|
from pyRDDLGym.core.compiler.model import RDDLLiftedModel
|
|
27
29
|
from pyRDDLGym.core.debug.exception import (
|
|
28
30
|
RDDLActionPreconditionNotSatisfiedError,
|
|
@@ -35,7 +37,7 @@ from pyRDDLGym.core.simulator import RDDLSimulator
|
|
|
35
37
|
|
|
36
38
|
from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler
|
|
37
39
|
|
|
38
|
-
Args = Dict[str, Value]
|
|
40
|
+
Args = Dict[str, Union[np.ndarray, Value]]
|
|
39
41
|
|
|
40
42
|
|
|
41
43
|
class JaxRDDLSimulator(RDDLSimulator):
|
|
@@ -45,6 +47,7 @@ class JaxRDDLSimulator(RDDLSimulator):
|
|
|
45
47
|
raise_error: bool=True,
|
|
46
48
|
logger: Optional[Logger]=None,
|
|
47
49
|
keep_tensors: bool=False,
|
|
50
|
+
objects_as_strings: bool=True,
|
|
48
51
|
**compiler_args) -> None:
|
|
49
52
|
'''Creates a new simulator for the given RDDL model with Jax as a backend.
|
|
50
53
|
|
|
@@ -57,6 +60,8 @@ class JaxRDDLSimulator(RDDLSimulator):
|
|
|
57
60
|
:param logger: to log information about compilation to file
|
|
58
61
|
:param keep_tensors: whether the sampler takes actions and
|
|
59
62
|
returns state in numpy array form
|
|
63
|
+
:param objects_as_strings: whether to return object values as strings (defaults
|
|
64
|
+
to integer indices if False)
|
|
60
65
|
:param **compiler_args: keyword arguments to pass to the Jax compiler
|
|
61
66
|
'''
|
|
62
67
|
if key is None:
|
|
@@ -67,7 +72,8 @@ class JaxRDDLSimulator(RDDLSimulator):
|
|
|
67
72
|
|
|
68
73
|
# generate direct sampling with default numpy RNG and operations
|
|
69
74
|
super(JaxRDDLSimulator, self).__init__(
|
|
70
|
-
rddl, logger=logger,
|
|
75
|
+
rddl, logger=logger,
|
|
76
|
+
keep_tensors=keep_tensors, objects_as_strings=objects_as_strings)
|
|
71
77
|
|
|
72
78
|
def seed(self, seed: int) -> None:
|
|
73
79
|
super(JaxRDDLSimulator, self).seed(seed)
|
|
@@ -84,11 +90,11 @@ class JaxRDDLSimulator(RDDLSimulator):
|
|
|
84
90
|
self.levels = compiled.levels
|
|
85
91
|
self.traced = compiled.traced
|
|
86
92
|
|
|
87
|
-
self.invariants = jax.tree_map(jax.jit, compiled.invariants)
|
|
88
|
-
self.preconds = jax.tree_map(jax.jit, compiled.preconditions)
|
|
89
|
-
self.terminals = jax.tree_map(jax.jit, compiled.terminations)
|
|
93
|
+
self.invariants = jax.tree_util.tree_map(jax.jit, compiled.invariants)
|
|
94
|
+
self.preconds = jax.tree_util.tree_map(jax.jit, compiled.preconditions)
|
|
95
|
+
self.terminals = jax.tree_util.tree_map(jax.jit, compiled.terminations)
|
|
90
96
|
self.reward = jax.jit(compiled.reward)
|
|
91
|
-
jax_cpfs = jax.tree_map(jax.jit, compiled.cpfs)
|
|
97
|
+
jax_cpfs = jax.tree_util.tree_map(jax.jit, compiled.cpfs)
|
|
92
98
|
self.model_params = compiled.model_params
|
|
93
99
|
|
|
94
100
|
# level analysis
|
|
@@ -139,7 +145,6 @@ class JaxRDDLSimulator(RDDLSimulator):
|
|
|
139
145
|
|
|
140
146
|
def check_action_preconditions(self, actions: Args, silent: bool=False) -> bool:
|
|
141
147
|
'''Throws an exception if the action preconditions are not satisfied.'''
|
|
142
|
-
actions = self._process_actions(actions)
|
|
143
148
|
subs = self.subs
|
|
144
149
|
subs.update(actions)
|
|
145
150
|
|
|
@@ -180,7 +185,6 @@ class JaxRDDLSimulator(RDDLSimulator):
|
|
|
180
185
|
'''
|
|
181
186
|
rddl = self.rddl
|
|
182
187
|
keep_tensors = self.keep_tensors
|
|
183
|
-
actions = self._process_actions(actions)
|
|
184
188
|
subs = self.subs
|
|
185
189
|
subs.update(actions)
|
|
186
190
|
|
|
@@ -196,20 +200,40 @@ class JaxRDDLSimulator(RDDLSimulator):
|
|
|
196
200
|
# update state
|
|
197
201
|
self.state = {}
|
|
198
202
|
for (state, next_state) in rddl.next_state.items():
|
|
203
|
+
|
|
204
|
+
# set state = state' for the next epoch
|
|
199
205
|
subs[state] = subs[next_state]
|
|
206
|
+
|
|
207
|
+
# convert object integer to string representation
|
|
208
|
+
state_values = subs[state]
|
|
209
|
+
if self.objects_as_strings:
|
|
210
|
+
ptype = rddl.variable_ranges[state]
|
|
211
|
+
if ptype not in RDDLValueInitializer.NUMPY_TYPES:
|
|
212
|
+
state_values = rddl.index_to_object_string_array(ptype, state_values)
|
|
213
|
+
|
|
214
|
+
# optional grounding of state dictionary
|
|
200
215
|
if keep_tensors:
|
|
201
|
-
self.state[state] =
|
|
216
|
+
self.state[state] = state_values
|
|
202
217
|
else:
|
|
203
|
-
self.state.update(rddl.ground_var_with_values(state,
|
|
218
|
+
self.state.update(rddl.ground_var_with_values(state, state_values))
|
|
204
219
|
|
|
205
220
|
# update observation
|
|
206
221
|
if self._pomdp:
|
|
207
222
|
obs = {}
|
|
208
223
|
for var in rddl.observ_fluents:
|
|
224
|
+
|
|
225
|
+
# convert object integer to string representation
|
|
226
|
+
obs_values = subs[var]
|
|
227
|
+
if self.objects_as_strings:
|
|
228
|
+
ptype = rddl.variable_ranges[var]
|
|
229
|
+
if ptype not in RDDLValueInitializer.NUMPY_TYPES:
|
|
230
|
+
obs_values = rddl.index_to_object_string_array(ptype, obs_values)
|
|
231
|
+
|
|
232
|
+
# optional grounding of observ-fluent dictionary
|
|
209
233
|
if keep_tensors:
|
|
210
|
-
obs[var] =
|
|
234
|
+
obs[var] = obs_values
|
|
211
235
|
else:
|
|
212
|
-
obs.update(rddl.ground_var_with_values(var,
|
|
236
|
+
obs.update(rddl.ground_var_with_values(var, obs_values))
|
|
213
237
|
else:
|
|
214
238
|
obs = self.state
|
|
215
239
|
|
pyRDDLGym_jax/core/tuning.py
CHANGED
|
@@ -18,6 +18,7 @@ import datetime
|
|
|
18
18
|
import threading
|
|
19
19
|
import multiprocessing
|
|
20
20
|
import os
|
|
21
|
+
import termcolor
|
|
21
22
|
import time
|
|
22
23
|
import traceback
|
|
23
24
|
from typing import Any, Callable, Dict, Iterable, Optional, Tuple
|
|
@@ -45,8 +46,7 @@ try:
|
|
|
45
46
|
from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard
|
|
46
47
|
except Exception:
|
|
47
48
|
raise_warning('Failed to load the dashboard visualization tool: '
|
|
48
|
-
'please make sure you have installed the required packages.',
|
|
49
|
-
'red')
|
|
49
|
+
'please make sure you have installed the required packages.', 'red')
|
|
50
50
|
traceback.print_exc()
|
|
51
51
|
JaxPlannerDashboard = None
|
|
52
52
|
|
|
@@ -159,24 +159,24 @@ class JaxParameterTuning:
|
|
|
159
159
|
kernel3 = Matern(length_scale=5.0, length_scale_bounds=(1.0, 5.0), nu=2.5)
|
|
160
160
|
return weight1 * kernel1 + weight2 * kernel2 + weight3 * kernel3
|
|
161
161
|
|
|
162
|
-
def summarize_hyperparameters(self) ->
|
|
162
|
+
def summarize_hyperparameters(self) -> str:
|
|
163
163
|
hyper_params_table = []
|
|
164
164
|
for (_, param) in self.hyperparams_dict.items():
|
|
165
165
|
hyper_params_table.append(f' {str(param)}')
|
|
166
166
|
hyper_params_table = '\n'.join(hyper_params_table)
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
167
|
+
return (f'hyperparameter optimizer parameters:\n'
|
|
168
|
+
f' tuned_hyper_parameters =\n{hyper_params_table}\n'
|
|
169
|
+
f' initialization_args ={self.gp_init_kwargs}\n'
|
|
170
|
+
f' gp_params ={self.gp_params}\n'
|
|
171
|
+
f' tuning_iterations ={self.gp_iters}\n'
|
|
172
|
+
f' tuning_timeout ={self.timeout_tuning}\n'
|
|
173
|
+
f' tuning_batch_size ={self.num_workers}\n'
|
|
174
|
+
f' mp_pool_context_type ={self.pool_context}\n'
|
|
175
|
+
f' mp_pool_poll_frequency ={self.poll_frequency}\n'
|
|
176
|
+
f'meta-objective parameters:\n'
|
|
177
|
+
f' planning_trials_per_iter ={self.eval_trials}\n'
|
|
178
|
+
f' rollouts_per_trial ={self.rollouts_per_trial}\n'
|
|
179
|
+
f' acquisition_fn ={self.acquisition}')
|
|
180
180
|
|
|
181
181
|
@staticmethod
|
|
182
182
|
def annealing_acquisition(n_samples: int, n_delay_samples: int=0,
|
|
@@ -346,6 +346,7 @@ class JaxParameterTuning:
|
|
|
346
346
|
|
|
347
347
|
# remove keywords that should not be in the tuner
|
|
348
348
|
train_args.pop('dashboard', None)
|
|
349
|
+
planner_args.pop('parallel_updates', None)
|
|
349
350
|
|
|
350
351
|
# initialize env for evaluation (need fresh copy to avoid concurrency)
|
|
351
352
|
env = RDDLEnv(domain, instance, vectorized=True, enforce_action_constraints=False)
|
|
@@ -368,18 +369,32 @@ class JaxParameterTuning:
|
|
|
368
369
|
|
|
369
370
|
def tune_optimizer(self, optimizer: BayesianOptimization) -> None:
|
|
370
371
|
'''Tunes the Bayesian optimization algorithm hyper-parameters.'''
|
|
371
|
-
print(
|
|
372
|
+
print(f'Kernel: {repr(optimizer._gp.kernel_)}.')
|
|
372
373
|
|
|
373
|
-
def tune(self, key: int,
|
|
374
|
-
|
|
374
|
+
def tune(self, key: int,
|
|
375
|
+
log_file: Optional[str]=None,
|
|
376
|
+
show_dashboard: bool=False,
|
|
377
|
+
print_hyperparams: bool=False) -> ParameterValues:
|
|
378
|
+
'''Tunes the hyper-parameters for Jax planner, returns the best found.
|
|
375
379
|
|
|
376
|
-
|
|
380
|
+
:param key: RNG key to seed the hyper-parameter optimizer
|
|
381
|
+
:param log_file: optional path to file where tuning progress will be saved
|
|
382
|
+
:param show_dashboard: whether to display tuning results in a dashboard
|
|
383
|
+
:param print_hyperparams: whether to print a hyper-parameter summary of the
|
|
384
|
+
optimizer
|
|
385
|
+
'''
|
|
377
386
|
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
387
|
+
if self.verbose:
|
|
388
|
+
print(JaxBackpropPlanner.summarize_system())
|
|
389
|
+
if print_hyperparams:
|
|
390
|
+
print(self.summarize_hyperparameters())
|
|
382
391
|
|
|
392
|
+
# clear and prepare output file
|
|
393
|
+
if log_file is not None:
|
|
394
|
+
with open(log_file, 'w', newline='') as file:
|
|
395
|
+
writer = csv.writer(file)
|
|
396
|
+
writer.writerow(COLUMNS + list(self.hyperparams_dict.keys()))
|
|
397
|
+
|
|
383
398
|
# create a dash-board for visualizing experiment runs
|
|
384
399
|
if show_dashboard and JaxPlannerDashboard is not None:
|
|
385
400
|
dashboard = JaxPlannerDashboard()
|
|
@@ -445,13 +460,15 @@ class JaxParameterTuning:
|
|
|
445
460
|
# check if there is enough time left for another iteration
|
|
446
461
|
elapsed = time.time() - start_time
|
|
447
462
|
if elapsed >= self.timeout_tuning:
|
|
448
|
-
|
|
463
|
+
message = termcolor.colored(
|
|
464
|
+
f'[INFO] Global time limit reached at iteration {it}.', 'green')
|
|
465
|
+
print(message)
|
|
449
466
|
break
|
|
450
467
|
|
|
451
468
|
# continue with next iteration
|
|
452
469
|
print('\n' + '*' * 80 +
|
|
453
470
|
f'\n[{datetime.timedelta(seconds=elapsed)}] ' +
|
|
454
|
-
f'
|
|
471
|
+
f'Starting iteration {it + 1}' +
|
|
455
472
|
'\n' + '*' * 80)
|
|
456
473
|
key, *subkeys = jax.random.split(key, num=num_workers + 1)
|
|
457
474
|
rows = [None] * num_workers
|
|
@@ -507,15 +524,19 @@ class JaxParameterTuning:
|
|
|
507
524
|
|
|
508
525
|
# print best parameter if found
|
|
509
526
|
if best_target > old_best_target:
|
|
510
|
-
|
|
527
|
+
message = termcolor.colored(
|
|
528
|
+
f'[INFO] Found new best average reward {best_target:.6f}.',
|
|
529
|
+
'green')
|
|
530
|
+
print(message)
|
|
511
531
|
|
|
512
532
|
# tune the optimizer here
|
|
513
533
|
self.tune_optimizer(optimizer)
|
|
514
534
|
|
|
515
535
|
# write results of all processes in current iteration to file
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
536
|
+
if log_file is not None:
|
|
537
|
+
with open(log_file, 'a', newline='') as file:
|
|
538
|
+
writer = csv.writer(file)
|
|
539
|
+
writer.writerows(rows)
|
|
519
540
|
|
|
520
541
|
# update the dashboard tuning
|
|
521
542
|
if show_dashboard:
|
|
@@ -528,7 +549,7 @@ class JaxParameterTuning:
|
|
|
528
549
|
|
|
529
550
|
# print summary of results
|
|
530
551
|
elapsed = time.time() - start_time
|
|
531
|
-
print(f'
|
|
552
|
+
print(f'Summary of hyper-parameter optimization:\n'
|
|
532
553
|
f' time_elapsed ={datetime.timedelta(seconds=elapsed)}\n'
|
|
533
554
|
f' iterations ={it + 1}\n'
|
|
534
555
|
f' best_hyper_parameters={best_params}\n'
|
pyRDDLGym_jax/entry_point.py
CHANGED
|
@@ -2,24 +2,56 @@ import argparse
|
|
|
2
2
|
|
|
3
3
|
from pyRDDLGym_jax.examples import run_plan, run_tune
|
|
4
4
|
|
|
5
|
+
EPILOG = 'For complete documentation, see https://pyrddlgym.readthedocs.io/en/latest/jax.html.'
|
|
6
|
+
|
|
5
7
|
def main():
|
|
6
|
-
parser = argparse.ArgumentParser(
|
|
8
|
+
parser = argparse.ArgumentParser(prog='jaxplan',
|
|
9
|
+
description="command line parser for the jaxplan planner",
|
|
10
|
+
epilog=EPILOG)
|
|
7
11
|
subparsers = parser.add_subparsers(dest="jaxplan", required=True)
|
|
8
12
|
|
|
9
13
|
# planning
|
|
10
|
-
parser_plan = subparsers.add_parser("plan",
|
|
11
|
-
|
|
14
|
+
parser_plan = subparsers.add_parser("plan",
|
|
15
|
+
help="execute jaxplan on a specified RDDL problem",
|
|
16
|
+
epilog=EPILOG)
|
|
17
|
+
parser_plan.add_argument('domain', type=str,
|
|
18
|
+
help='name of domain in rddlrepository or a valid file path')
|
|
19
|
+
parser_plan.add_argument('instance', type=str,
|
|
20
|
+
help='name of instance in rddlrepository or a valid file path')
|
|
21
|
+
parser_plan.add_argument('method', type=str,
|
|
22
|
+
help='training method to apply: [slp, drp] are offline methods, and [replan] are online')
|
|
23
|
+
parser_plan.add_argument('-e', '--episodes', type=int, required=False, default=1,
|
|
24
|
+
help='number of training or evaluation episodes')
|
|
12
25
|
|
|
13
26
|
# tuning
|
|
14
|
-
parser_tune = subparsers.add_parser("tune",
|
|
15
|
-
|
|
27
|
+
parser_tune = subparsers.add_parser("tune",
|
|
28
|
+
help="tune jaxplan on a specified RDDL problem",
|
|
29
|
+
epilog=EPILOG)
|
|
30
|
+
parser_tune.add_argument('domain', type=str,
|
|
31
|
+
help='name of domain in rddlrepository or a valid file path')
|
|
32
|
+
parser_tune.add_argument('instance', type=str,
|
|
33
|
+
help='name of instance in rddlrepository or a valid file path')
|
|
34
|
+
parser_tune.add_argument('method', type=str,
|
|
35
|
+
help='training method to apply: [slp, drp] are offline methods, and [replan] are online')
|
|
36
|
+
parser_tune.add_argument('-t', '--trials', type=int, required=False, default=5,
|
|
37
|
+
help='number of evaluation rollouts per hyper-parameter choice')
|
|
38
|
+
parser_tune.add_argument('-i', '--iters', type=int, required=False, default=20,
|
|
39
|
+
help='number of iterations of bayesian optimization')
|
|
40
|
+
parser_tune.add_argument('-w', '--workers', type=int, required=False, default=4,
|
|
41
|
+
help='number of parallel hyper-parameters to evaluate per iteration')
|
|
42
|
+
parser_tune.add_argument('-d', '--dashboard', type=bool, required=False, default=False,
|
|
43
|
+
help='show the dashboard')
|
|
44
|
+
parser_tune.add_argument('-f', '--filepath', type=str, required=False, default='',
|
|
45
|
+
help='where to save the config file of the best hyper-parameters')
|
|
16
46
|
|
|
17
47
|
# dispatch
|
|
18
48
|
args = parser.parse_args()
|
|
19
49
|
if args.jaxplan == "plan":
|
|
20
|
-
run_plan.
|
|
50
|
+
run_plan.main(args.domain, args.instance, args.method, args.episodes)
|
|
21
51
|
elif args.jaxplan == "tune":
|
|
22
|
-
run_tune.
|
|
52
|
+
run_tune.main(args.domain, args.instance, args.method,
|
|
53
|
+
args.trials, args.iters, args.workers, args.dashboard,
|
|
54
|
+
args.filepath)
|
|
23
55
|
else:
|
|
24
56
|
parser.print_help()
|
|
25
57
|
|
|
@@ -26,7 +26,7 @@ from pyRDDLGym_jax.core.planner import (
|
|
|
26
26
|
)
|
|
27
27
|
|
|
28
28
|
|
|
29
|
-
def main(domain, instance, method, episodes=1):
|
|
29
|
+
def main(domain: str, instance: str, method: str, episodes: int=1) -> None:
|
|
30
30
|
|
|
31
31
|
# set up the environment
|
|
32
32
|
env = pyRDDLGym.make(domain, instance, vectorized=True)
|
|
@@ -36,8 +36,8 @@ def main(domain, instance, method, episodes=1):
|
|
|
36
36
|
abs_path = os.path.dirname(os.path.abspath(__file__))
|
|
37
37
|
config_path = os.path.join(abs_path, 'configs', f'{domain}_{method}.cfg')
|
|
38
38
|
if not os.path.isfile(config_path):
|
|
39
|
-
raise_warning(f'Config file {config_path} was not found, '
|
|
40
|
-
f'using default_{method}.cfg.', '
|
|
39
|
+
raise_warning(f'[WARN] Config file {config_path} was not found, '
|
|
40
|
+
f'using default_{method}.cfg.', 'yellow')
|
|
41
41
|
config_path = os.path.join(abs_path, 'configs', f'default_{method}.cfg')
|
|
42
42
|
elif os.path.isfile(method):
|
|
43
43
|
config_path = method
|
|
@@ -31,8 +31,8 @@ def main(domain, instance, method, episodes=1):
|
|
|
31
31
|
abs_path = os.path.dirname(os.path.abspath(__file__))
|
|
32
32
|
config_path = os.path.join(abs_path, 'configs', f'{domain}_slp.cfg')
|
|
33
33
|
if not os.path.isfile(config_path):
|
|
34
|
-
raise_warning(f'Config file {config_path} was not found, '
|
|
35
|
-
f'using default_slp.cfg.', '
|
|
34
|
+
raise_warning(f'[WARN] Config file {config_path} was not found, '
|
|
35
|
+
f'using default_slp.cfg.', 'yellow')
|
|
36
36
|
config_path = os.path.join(abs_path, 'configs', 'default_slp.cfg')
|
|
37
37
|
planner_args, _, train_args = load_config(config_path)
|
|
38
38
|
|
|
@@ -36,7 +36,9 @@ def power_10(x):
|
|
|
36
36
|
return 10.0 ** x
|
|
37
37
|
|
|
38
38
|
|
|
39
|
-
def main(domain, instance, method
|
|
39
|
+
def main(domain: str, instance: str, method: str,
|
|
40
|
+
trials: int=5, iters: int=20, workers: int=4, dashboard: bool=False,
|
|
41
|
+
filepath: str='') -> None:
|
|
40
42
|
|
|
41
43
|
# set up the environment
|
|
42
44
|
env = pyRDDLGym.make(domain, instance, vectorized=True)
|
|
@@ -68,6 +70,9 @@ def main(domain, instance, method, trials=5, iters=20, workers=4, dashboard=Fals
|
|
|
68
70
|
tuning.tune(key=42,
|
|
69
71
|
log_file=f'gp_{method}_{domain}_{instance}.csv',
|
|
70
72
|
show_dashboard=dashboard)
|
|
73
|
+
if filepath is not None and filepath:
|
|
74
|
+
with open(filepath, "w") as file:
|
|
75
|
+
file.write(tuning.best_config)
|
|
71
76
|
|
|
72
77
|
# evaluate the agent on the best parameters
|
|
73
78
|
planner_args, _, train_args = load_config_from_string(tuning.best_config)
|
|
@@ -80,7 +85,7 @@ def main(domain, instance, method, trials=5, iters=20, workers=4, dashboard=Fals
|
|
|
80
85
|
|
|
81
86
|
def run_from_args(args):
|
|
82
87
|
if len(args) < 3:
|
|
83
|
-
print('python run_tune.py <domain> <instance> <method> [<trials>] [<iters>] [<workers>] [<dashboard>]')
|
|
88
|
+
print('python run_tune.py <domain> <instance> <method> [<trials>] [<iters>] [<workers>] [<dashboard>] [<filepath>]')
|
|
84
89
|
exit(1)
|
|
85
90
|
if args[2] not in ['drp', 'slp', 'replan']:
|
|
86
91
|
print('<method> in [drp, slp, replan]')
|
|
@@ -90,6 +95,7 @@ def run_from_args(args):
|
|
|
90
95
|
if len(args) >= 5: kwargs['iters'] = int(args[4])
|
|
91
96
|
if len(args) >= 6: kwargs['workers'] = int(args[5])
|
|
92
97
|
if len(args) >= 7: kwargs['dashboard'] = bool(args[6])
|
|
98
|
+
if len(args) >= 8: kwargs['filepath'] = bool(args[7])
|
|
93
99
|
main(**kwargs)
|
|
94
100
|
|
|
95
101
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: pyRDDLGym-jax
|
|
3
|
-
Version: 2.
|
|
3
|
+
Version: 2.5
|
|
4
4
|
Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
|
|
5
5
|
Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
|
|
6
6
|
Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
|
|
@@ -39,6 +39,7 @@ Dynamic: description
|
|
|
39
39
|
Dynamic: description-content-type
|
|
40
40
|
Dynamic: home-page
|
|
41
41
|
Dynamic: license
|
|
42
|
+
Dynamic: license-file
|
|
42
43
|
Dynamic: provides-extra
|
|
43
44
|
Dynamic: requires-dist
|
|
44
45
|
Dynamic: requires-python
|
|
@@ -116,7 +117,7 @@ pip install pyRDDLGym-jax[extra,dashboard]
|
|
|
116
117
|
A basic run script is provided to train JaxPlan on any RDDL problem:
|
|
117
118
|
|
|
118
119
|
```shell
|
|
119
|
-
jaxplan plan <domain> <instance> <method> <episodes>
|
|
120
|
+
jaxplan plan <domain> <instance> <method> --episodes <episodes>
|
|
120
121
|
```
|
|
121
122
|
|
|
122
123
|
where:
|
|
@@ -241,7 +242,7 @@ More documentation about this and other new features will be coming soon.
|
|
|
241
242
|
A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:
|
|
242
243
|
|
|
243
244
|
```shell
|
|
244
|
-
jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
|
|
245
|
+
jaxplan tune <domain> <instance> <method> --trials <trials> --iters <iters> --workers <workers> --dashboard <dashboard> --filepath <filepath>
|
|
245
246
|
```
|
|
246
247
|
|
|
247
248
|
where:
|
|
@@ -251,7 +252,8 @@ where:
|
|
|
251
252
|
- ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
|
|
252
253
|
- ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
|
|
253
254
|
- ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``
|
|
254
|
-
- ``dashboard`` is whether the optimizations are tracked in the dashboard application
|
|
255
|
+
- ``dashboard`` is whether the optimizations are tracked in the dashboard application
|
|
256
|
+
- ``filepath`` is the optional file path where a config file with the best hyper-parameter setting will be saved.
|
|
255
257
|
|
|
256
258
|
It is easy to tune a custom range of the planner's hyper-parameters efficiently.
|
|
257
259
|
First create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
|
|
@@ -291,23 +293,16 @@ env = pyRDDLGym.make(domain, instance, vectorized=True)
|
|
|
291
293
|
with open('path/to/config.cfg', 'r') as file:
|
|
292
294
|
config_template = file.read()
|
|
293
295
|
|
|
294
|
-
#
|
|
296
|
+
# tune weight from 10^-1 ... 10^5 and lr from 10^-5 ... 10^1
|
|
295
297
|
def power_10(x):
|
|
296
|
-
return 10.0 ** x
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10), # tune weight from 10^-1 ... 10^5
|
|
300
|
-
Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10), # tune lr from 10^-5 ... 10^1
|
|
301
|
-
]
|
|
298
|
+
return 10.0 ** x
|
|
299
|
+
hyperparams = [Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),
|
|
300
|
+
Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10)]
|
|
302
301
|
|
|
303
302
|
# build the tuner and tune
|
|
304
303
|
tuning = JaxParameterTuning(env=env,
|
|
305
|
-
config_template=config_template,
|
|
306
|
-
|
|
307
|
-
online=False,
|
|
308
|
-
eval_trials=trials,
|
|
309
|
-
num_workers=workers,
|
|
310
|
-
gp_iters=iters)
|
|
304
|
+
config_template=config_template, hyperparams=hyperparams,
|
|
305
|
+
online=False, eval_trials=trials, num_workers=workers, gp_iters=iters)
|
|
311
306
|
tuning.tune(key=42, log_file='path/to/log.csv')
|
|
312
307
|
```
|
|
313
308
|
|
|
@@ -1,20 +1,20 @@
|
|
|
1
|
-
pyRDDLGym_jax/__init__.py,sha256=
|
|
2
|
-
pyRDDLGym_jax/entry_point.py,sha256=
|
|
1
|
+
pyRDDLGym_jax/__init__.py,sha256=VoxLo_sy8RlJIIyu7szqL-cdMGBJdQPg-aSeyOVVIkY,19
|
|
2
|
+
pyRDDLGym_jax/entry_point.py,sha256=K0zy1oe66jfBHkHHCM6aGHbbiVqnQvDhDb8se4uaKHE,3319
|
|
3
3
|
pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
4
|
-
pyRDDLGym_jax/core/compiler.py,sha256=
|
|
5
|
-
pyRDDLGym_jax/core/logic.py,sha256=
|
|
6
|
-
pyRDDLGym_jax/core/planner.py,sha256=
|
|
7
|
-
pyRDDLGym_jax/core/simulator.py,sha256=
|
|
8
|
-
pyRDDLGym_jax/core/tuning.py,sha256=
|
|
4
|
+
pyRDDLGym_jax/core/compiler.py,sha256=uFCtoipsIa3MM9nGgT3X8iCViPl2XSPNXh0jMdzN0ko,82895
|
|
5
|
+
pyRDDLGym_jax/core/logic.py,sha256=lfc2ak_ap_ajMEFlB5EHCRNgJym31dNyA-5d-7N4CZA,56271
|
|
6
|
+
pyRDDLGym_jax/core/planner.py,sha256=M6GKzN7Ml57B4ZrFZhhkpsQCvReKaCQNzer7zeHCM9E,140275
|
|
7
|
+
pyRDDLGym_jax/core/simulator.py,sha256=ayCATTUL3clLaZPQ5OUg2bI_c26KKCTq6TbrxbMsVdc,10470
|
|
8
|
+
pyRDDLGym_jax/core/tuning.py,sha256=BWcQZk02TMLexTz1Sw4lX2EQKvmPbp7biC51M-IiNUw,25153
|
|
9
9
|
pyRDDLGym_jax/core/visualization.py,sha256=4BghMp8N7qtF0tdyDSqtxAxNfP9HPrQWTiXzAMJmx7o,70365
|
|
10
10
|
pyRDDLGym_jax/core/assets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
11
|
pyRDDLGym_jax/core/assets/favicon.ico,sha256=RMMrI9YvmF81TgYG7FO7UAre6WmYFkV3B2GmbA1l0kM,175085
|
|
12
12
|
pyRDDLGym_jax/examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
13
13
|
pyRDDLGym_jax/examples/run_gradient.py,sha256=KhXvijRDZ4V7N8NOI2WV8ePGpPna5_vnET61YwS7Tco,2919
|
|
14
14
|
pyRDDLGym_jax/examples/run_gym.py,sha256=rXvNWkxe4jHllvbvU_EOMji_2-2k5d4tbBKhpMm_Gaw,1526
|
|
15
|
-
pyRDDLGym_jax/examples/run_plan.py,sha256=
|
|
16
|
-
pyRDDLGym_jax/examples/run_scipy.py,sha256=
|
|
17
|
-
pyRDDLGym_jax/examples/run_tune.py,sha256=
|
|
15
|
+
pyRDDLGym_jax/examples/run_plan.py,sha256=4y7JHqTxY5O1ltP6N7rar0jMiw7u9w1nuAIOcmDaAuE,2806
|
|
16
|
+
pyRDDLGym_jax/examples/run_scipy.py,sha256=7uVnDXb7D3NTJqA2L8nrcYDJP-k0ba9dl9YqA2CD9ac,2301
|
|
17
|
+
pyRDDLGym_jax/examples/run_tune.py,sha256=F5KWgtoCPbf7XHB6HW9LjxarD57U2LvuGdTz67OL1DY,4114
|
|
18
18
|
pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg,sha256=mE8MqhOlkHeXIGEVrnR3QY6I-_iy4uxFYRA71P1bmtk,347
|
|
19
19
|
pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg,sha256=nFFYHCKQUMn8x-OpJwu2pwe1tycNSJ8iAIwSkCBn33E,370
|
|
20
20
|
pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg,sha256=eJ3HvHjODoKdtX7u-AM51xQaHJnYgzEy2t3omNG2oCs,340
|
|
@@ -38,12 +38,12 @@ pyRDDLGym_jax/examples/configs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5
|
|
|
38
38
|
pyRDDLGym_jax/examples/configs/default_drp.cfg,sha256=XeMWAAG_OFZo7JAMxS5-XXroZaeVMzfM0NswmEobIns,373
|
|
39
39
|
pyRDDLGym_jax/examples/configs/default_replan.cfg,sha256=CK4cEz8ReXyAZPLaLG9clIIRXAqM3IplUCxbLt_V2lY,407
|
|
40
40
|
pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=mJo0woDevhQCSQfJg30ULVy9qGIJDIw73XCe6pyIPtg,369
|
|
41
|
-
pyRDDLGym_jax/examples/configs/tuning_drp.cfg,sha256=
|
|
42
|
-
pyRDDLGym_jax/examples/configs/tuning_replan.cfg,sha256=
|
|
43
|
-
pyRDDLGym_jax/examples/configs/tuning_slp.cfg,sha256=
|
|
44
|
-
pyrddlgym_jax-2.
|
|
45
|
-
pyrddlgym_jax-2.
|
|
46
|
-
pyrddlgym_jax-2.
|
|
47
|
-
pyrddlgym_jax-2.
|
|
48
|
-
pyrddlgym_jax-2.
|
|
49
|
-
pyrddlgym_jax-2.
|
|
41
|
+
pyRDDLGym_jax/examples/configs/tuning_drp.cfg,sha256=zocZn_cVarH5i0hOlt2Zu0NwmXYBmTTghLaXLtQOGto,526
|
|
42
|
+
pyRDDLGym_jax/examples/configs/tuning_replan.cfg,sha256=9oIhtw9cuikmlbDgCgbrTc5G7hUio-HeAv_3CEGVclY,523
|
|
43
|
+
pyRDDLGym_jax/examples/configs/tuning_slp.cfg,sha256=QqnyR__5-HhKeCDfGDel8VIlqsjxRHk4SSH089zJP8s,486
|
|
44
|
+
pyrddlgym_jax-2.5.dist-info/licenses/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
|
|
45
|
+
pyrddlgym_jax-2.5.dist-info/METADATA,sha256=XAaEJfbsYW-txxZhFZ6o_HmvqxkIMTqBF9LbV-KdTzI,17058
|
|
46
|
+
pyrddlgym_jax-2.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
47
|
+
pyrddlgym_jax-2.5.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
|
|
48
|
+
pyrddlgym_jax-2.5.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
|
|
49
|
+
pyrddlgym_jax-2.5.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|