pyRDDLGym-jax 2.4-py3-none-any.whl → 2.6-py3-none-any.whl
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +23 -10
- pyRDDLGym_jax/core/logic.py +6 -8
- pyRDDLGym_jax/core/model.py +595 -0
- pyRDDLGym_jax/core/planner.py +317 -99
- pyRDDLGym_jax/core/simulator.py +37 -13
- pyRDDLGym_jax/core/tuning.py +25 -10
- pyRDDLGym_jax/entry_point.py +39 -7
- pyRDDLGym_jax/examples/configs/tuning_drp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/tuning_replan.cfg +1 -0
- pyRDDLGym_jax/examples/configs/tuning_slp.cfg +1 -0
- pyRDDLGym_jax/examples/run_plan.py +1 -1
- pyRDDLGym_jax/examples/run_tune.py +8 -2
- {pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.6.dist-info}/METADATA +17 -30
- {pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.6.dist-info}/RECORD +19 -18
- {pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.6.dist-info}/WHEEL +1 -1
- {pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.6.dist-info}/entry_points.txt +0 -0
- {pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.6.dist-info/licenses}/LICENSE +0 -0
- {pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.6.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/simulator.py
CHANGED
@@ -19,10 +19,12 @@
 
 
 import time
-
+import numpy as np
+from typing import Dict, Optional, Union
 
 import jax
 
+from pyRDDLGym.core.compiler.initializer import RDDLValueInitializer
 from pyRDDLGym.core.compiler.model import RDDLLiftedModel
 from pyRDDLGym.core.debug.exception import (
     RDDLActionPreconditionNotSatisfiedError,
@@ -35,7 +37,7 @@ from pyRDDLGym.core.simulator import RDDLSimulator
 
 from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler
 
-Args = Dict[str, Value]
+Args = Dict[str, Union[np.ndarray, Value]]
 
 
 class JaxRDDLSimulator(RDDLSimulator):
@@ -45,6 +47,7 @@ class JaxRDDLSimulator(RDDLSimulator):
                  raise_error: bool=True,
                  logger: Optional[Logger]=None,
                  keep_tensors: bool=False,
+                 objects_as_strings: bool=True,
                  **compiler_args) -> None:
         '''Creates a new simulator for the given RDDL model with Jax as a backend.
         
@@ -57,6 +60,8 @@ class JaxRDDLSimulator(RDDLSimulator):
         :param logger: to log information about compilation to file
         :param keep_tensors: whether the sampler takes actions and
         returns state in numpy array form
+        :param objects_as_strings: whether to return object values as strings (defaults
+        to integer indices if False)
         :param **compiler_args: keyword arguments to pass to the Jax compiler
         '''
         if key is None:
@@ -67,7 +72,8 @@ class JaxRDDLSimulator(RDDLSimulator):
         
         # generate direct sampling with default numpy RNG and operations
         super(JaxRDDLSimulator, self).__init__(
-            rddl, logger=logger,
+            rddl, logger=logger,
+            keep_tensors=keep_tensors, objects_as_strings=objects_as_strings)
     
     def seed(self, seed: int) -> None:
         super(JaxRDDLSimulator, self).seed(seed)
@@ -84,11 +90,11 @@ class JaxRDDLSimulator(RDDLSimulator):
         self.levels = compiled.levels
         self.traced = compiled.traced
         
-        self.invariants = jax.tree_map(jax.jit, compiled.invariants)
-        self.preconds = jax.tree_map(jax.jit, compiled.preconditions)
-        self.terminals = jax.tree_map(jax.jit, compiled.terminations)
+        self.invariants = jax.tree_util.tree_map(jax.jit, compiled.invariants)
+        self.preconds = jax.tree_util.tree_map(jax.jit, compiled.preconditions)
+        self.terminals = jax.tree_util.tree_map(jax.jit, compiled.terminations)
         self.reward = jax.jit(compiled.reward)
-        jax_cpfs = jax.tree_map(jax.jit, compiled.cpfs)
+        jax_cpfs = jax.tree_util.tree_map(jax.jit, compiled.cpfs)
         self.model_params = compiled.model_params
         
         # level analysis
@@ -139,7 +145,6 @@ class JaxRDDLSimulator(RDDLSimulator):
     
     def check_action_preconditions(self, actions: Args, silent: bool=False) -> bool:
         '''Throws an exception if the action preconditions are not satisfied.'''
-        actions = self._process_actions(actions)
         subs = self.subs
         subs.update(actions)
         
@@ -180,7 +185,6 @@ class JaxRDDLSimulator(RDDLSimulator):
         '''
         rddl = self.rddl
         keep_tensors = self.keep_tensors
-        actions = self._process_actions(actions)
         subs = self.subs
         subs.update(actions)
         
@@ -196,20 +200,40 @@ class JaxRDDLSimulator(RDDLSimulator):
         # update state
         self.state = {}
         for (state, next_state) in rddl.next_state.items():
+            
+            # set state = state' for the next epoch
             subs[state] = subs[next_state]
+            
+            # convert object integer to string representation
+            state_values = subs[state]
+            if self.objects_as_strings:
+                ptype = rddl.variable_ranges[state]
+                if ptype not in RDDLValueInitializer.NUMPY_TYPES:
+                    state_values = rddl.index_to_object_string_array(ptype, state_values)
+            
+            # optional grounding of state dictionary
             if keep_tensors:
-                self.state[state] =
+                self.state[state] = state_values
             else:
-                self.state.update(rddl.ground_var_with_values(state,
+                self.state.update(rddl.ground_var_with_values(state, state_values))
         
         # update observation
         if self._pomdp:
             obs = {}
             for var in rddl.observ_fluents:
+                
+                # convert object integer to string representation
+                obs_values = subs[var]
+                if self.objects_as_strings:
+                    ptype = rddl.variable_ranges[var]
+                    if ptype not in RDDLValueInitializer.NUMPY_TYPES:
+                        obs_values = rddl.index_to_object_string_array(ptype, obs_values)
+                
+                # optional grounding of observ-fluent dictionary
                 if keep_tensors:
-                    obs[var] =
+                    obs[var] = obs_values
                 else:
-                    obs.update(rddl.ground_var_with_values(var,
+                    obs.update(rddl.ground_var_with_values(var, obs_values))
         else:
             obs = self.state
 
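The recurring `jax.tree_map` → `jax.tree_util.tree_map` substitution in this file tracks a JAX deprecation: the top-level `jax.tree_map` alias was deprecated in newer JAX releases in favor of the identically behaved `jax.tree_util.tree_map`. A minimal sketch of the pattern being used (jit-compiling every function stored in a pytree), with illustrative stand-ins for the package's compiled expressions:

```python
# Minimal sketch: `double` and `increment` stand in for the compiled RDDL
# CPFs/constraints that the simulator keeps in plain dicts.
import jax

def double(x):
    return x * 2.0

def increment(x):
    return x + 1.0

fns = {'cpf1': double, 'cpf2': increment}

# old spelling (deprecated): jax.tree_map(jax.jit, fns)
jitted = jax.tree_util.tree_map(jax.jit, fns)  # jit-compile every leaf function

print(jitted['cpf1'](3.0))  # 6.0
print(jitted['cpf2'](3.0))  # 4.0
```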
pyRDDLGym_jax/core/tuning.py
CHANGED
@@ -371,16 +371,30 @@ class JaxParameterTuning:
         '''Tunes the Bayesian optimization algorithm hyper-parameters.'''
         print(f'Kernel: {repr(optimizer._gp.kernel_)}.')
     
-    def tune(self, key: int,
-
+    def tune(self, key: int,
+             log_file: Optional[str]=None,
+             show_dashboard: bool=False,
+             print_hyperparams: bool=False) -> ParameterValues:
+        '''Tunes the hyper-parameters for Jax planner, returns the best found.
         
-
+        :param key: RNG key to seed the hyper-parameter optimizer
+        :param log_file: optional path to file where tuning progress will be saved
+        :param show_dashboard: whether to display tuning results in a dashboard
+        :param print_hyperparams: whether to print a hyper-parameter summary of the
+        optimizer
+        '''
         
-
-
-
-
+        if self.verbose:
+            print(JaxBackpropPlanner.summarize_system())
+        if print_hyperparams:
+            print(self.summarize_hyperparameters())
         
+        # clear and prepare output file
+        if log_file is not None:
+            with open(log_file, 'w', newline='') as file:
+                writer = csv.writer(file)
+                writer.writerow(COLUMNS + list(self.hyperparams_dict.keys()))
+        
         # create a dash-board for visualizing experiment runs
         if show_dashboard and JaxPlannerDashboard is not None:
             dashboard = JaxPlannerDashboard()
@@ -519,9 +533,10 @@ class JaxParameterTuning:
         self.tune_optimizer(optimizer)
         
         # write results of all processes in current iteration to file
-
-
-
+        if log_file is not None:
+            with open(log_file, 'a', newline='') as file:
+                writer = csv.writer(file)
+                writer.writerows(rows)
         
         # update the dashboard tuning
         if show_dashboard:
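Putting the new `tune()` signature above to use: a hedged sketch of a call, where `env`, `config_template`, and `hyperparams` are assumed to be set up as in the README example later in this diff.

```python
# Sketch only; the constructor arguments mirror the README example in this diff.
tuning = JaxParameterTuning(env=env,
                            config_template=config_template,
                            hyperparams=hyperparams,
                            online=False, eval_trials=5,
                            num_workers=4, gp_iters=20)
best = tuning.tune(key=42,
                   log_file='gp_log.csv',   # now optional: None disables CSV logging
                   show_dashboard=False,    # new: track runs in the dashboard
                   print_hyperparams=True)  # new: print a hyper-parameter summary
```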
pyRDDLGym_jax/entry_point.py
CHANGED
@@ -2,24 +2,56 @@ import argparse
 
 from pyRDDLGym_jax.examples import run_plan, run_tune
 
+EPILOG = 'For complete documentation, see https://pyrddlgym.readthedocs.io/en/latest/jax.html.'
+
 def main():
-    parser = argparse.ArgumentParser(
+    parser = argparse.ArgumentParser(prog='jaxplan',
+                                     description="command line parser for the jaxplan planner",
+                                     epilog=EPILOG)
     subparsers = parser.add_subparsers(dest="jaxplan", required=True)
 
     # planning
-    parser_plan = subparsers.add_parser("plan",
-
+    parser_plan = subparsers.add_parser("plan",
+                                        help="execute jaxplan on a specified RDDL problem",
+                                        epilog=EPILOG)
+    parser_plan.add_argument('domain', type=str,
+                             help='name of domain in rddlrepository or a valid file path')
+    parser_plan.add_argument('instance', type=str,
+                             help='name of instance in rddlrepository or a valid file path')
+    parser_plan.add_argument('method', type=str,
+                             help='training method to apply: [slp, drp] are offline methods, and [replan] are online')
+    parser_plan.add_argument('-e', '--episodes', type=int, required=False, default=1,
+                             help='number of training or evaluation episodes')
 
     # tuning
-    parser_tune = subparsers.add_parser("tune",
-
+    parser_tune = subparsers.add_parser("tune",
+                                        help="tune jaxplan on a specified RDDL problem",
+                                        epilog=EPILOG)
+    parser_tune.add_argument('domain', type=str,
+                             help='name of domain in rddlrepository or a valid file path')
+    parser_tune.add_argument('instance', type=str,
+                             help='name of instance in rddlrepository or a valid file path')
+    parser_tune.add_argument('method', type=str,
+                             help='training method to apply: [slp, drp] are offline methods, and [replan] are online')
+    parser_tune.add_argument('-t', '--trials', type=int, required=False, default=5,
+                             help='number of evaluation rollouts per hyper-parameter choice')
+    parser_tune.add_argument('-i', '--iters', type=int, required=False, default=20,
+                             help='number of iterations of bayesian optimization')
+    parser_tune.add_argument('-w', '--workers', type=int, required=False, default=4,
+                             help='number of parallel hyper-parameters to evaluate per iteration')
+    parser_tune.add_argument('-d', '--dashboard', type=bool, required=False, default=False,
+                             help='show the dashboard')
+    parser_tune.add_argument('-f', '--filepath', type=str, required=False, default='',
+                             help='where to save the config file of the best hyper-parameters')
 
     # dispatch
     args = parser.parse_args()
     if args.jaxplan == "plan":
-        run_plan.
+        run_plan.main(args.domain, args.instance, args.method, args.episodes)
     elif args.jaxplan == "tune":
-        run_tune.
+        run_tune.main(args.domain, args.instance, args.method,
+                      args.trials, args.iters, args.workers, args.dashboard,
+                      args.filepath)
     else:
         parser.print_help()
 
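Because the subcommands dispatch straight to `run_plan.main` and `run_tune.main`, the same runs can also be launched from Python; the domain and instance names below are illustrative placeholders, not values shipped with the package.

```python
# Equivalent of: jaxplan plan <domain> <instance> slp --episodes 2
#           and: jaxplan tune <domain> <instance> drp --trials 5 --iters 20 --workers 4
from pyRDDLGym_jax.examples import run_plan, run_tune

run_plan.main('MyDomain', 'MyInstance', 'slp', episodes=2)
run_tune.main('MyDomain', 'MyInstance', 'drp',
              trials=5, iters=20, workers=4, dashboard=False, filepath='')
```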
pyRDDLGym_jax/examples/run_plan.py
CHANGED
@@ -26,7 +26,7 @@ from pyRDDLGym_jax.core.planner import (
 )
 
 
-def main(domain, instance, method, episodes=1):
+def main(domain: str, instance: str, method: str, episodes: int=1) -> None:
     
     # set up the environment
     env = pyRDDLGym.make(domain, instance, vectorized=True)
pyRDDLGym_jax/examples/run_tune.py
CHANGED
@@ -36,7 +36,9 @@ def power_10(x):
     return 10.0 ** x
 
 
-def main(domain, instance, method
+def main(domain: str, instance: str, method: str,
+         trials: int=5, iters: int=20, workers: int=4, dashboard: bool=False,
+         filepath: str='') -> None:
     
     # set up the environment
     env = pyRDDLGym.make(domain, instance, vectorized=True)
@@ -68,6 +70,9 @@ def main(domain, instance, method, trials=5, iters=20, workers=4, dashboard=Fals
     tuning.tune(key=42,
                 log_file=f'gp_{method}_{domain}_{instance}.csv',
                 show_dashboard=dashboard)
+    if filepath is not None and filepath:
+        with open(filepath, "w") as file:
+            file.write(tuning.best_config)
 
     # evaluate the agent on the best parameters
     planner_args, _, train_args = load_config_from_string(tuning.best_config)
@@ -80,7 +85,7 @@ def main(domain, instance, method, trials=5, iters=20, workers=4, dashboard=Fals
 
 def run_from_args(args):
     if len(args) < 3:
-        print('python run_tune.py <domain> <instance> <method> [<trials>] [<iters>] [<workers>] [<dashboard>]')
+        print('python run_tune.py <domain> <instance> <method> [<trials>] [<iters>] [<workers>] [<dashboard>] [<filepath>]')
         exit(1)
     if args[2] not in ['drp', 'slp', 'replan']:
         print('<method> in [drp, slp, replan]')
@@ -90,6 +95,7 @@ def run_from_args(args):
     if len(args) >= 5: kwargs['iters'] = int(args[4])
     if len(args) >= 6: kwargs['workers'] = int(args[5])
     if len(args) >= 7: kwargs['dashboard'] = bool(args[6])
+    if len(args) >= 8: kwargs['filepath'] = bool(args[7])
     main(**kwargs)
 
{pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.6.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: pyRDDLGym-jax
-Version: 2.4
+Version: 2.6
 Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
 Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
 Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -20,7 +20,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: pyRDDLGym>=2.
+Requires-Dist: pyRDDLGym>=2.3
 Requires-Dist: tqdm>=4.66
 Requires-Dist: jax>=0.4.12
 Requires-Dist: optax>=0.1.9
@@ -39,6 +39,7 @@ Dynamic: description
 Dynamic: description-content-type
 Dynamic: home-page
 Dynamic: license
+Dynamic: license-file
 Dynamic: provides-extra
 Dynamic: requires-dist
 Dynamic: requires-python
@@ -54,7 +55,7 @@ Dynamic: summary
 
 [Installation](#installation) | [Run cmd](#running-from-the-command-line) | [Run python](#running-from-another-python-application) | [Configuration](#configuring-the-planner) | [Dashboard](#jaxplan-dashboard) | [Tuning](#tuning-the-planner) | [Simulation](#simulation) | [Citing](#citing-jaxplan)
 
-**pyRDDLGym-jax (
+**pyRDDLGym-jax (or JaxPlan) is an efficient gradient-based planning algorithm based on JAX.**
 
 Purpose:
 
@@ -83,7 +84,7 @@ and was moved to the individual logic components which have their own unique wei
 
 > [!NOTE]
 > While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
-> If you find it is not making
+> If you find it is not making progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
 
 ## Installation
 
@@ -116,7 +117,7 @@ pip install pyRDDLGym-jax[extra,dashboard]
 A basic run script is provided to train JaxPlan on any RDDL problem:
 
 ```shell
-jaxplan plan <domain> <instance> <method> <episodes>
+jaxplan plan <domain> <instance> <method> --episodes <episodes>
 ```
 
 where:
@@ -219,13 +220,7 @@ controller = JaxOfflineController(planner, **train_args)
 ## JaxPlan Dashboard
 
 Since version 1.0, JaxPlan has an optional dashboard that allows keeping track of the planner performance across multiple runs,
-and visualization of the policy or model, and other useful debugging features.
-
-<p align="middle">
-<img src="https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/Images/dashboard.png" width="480" height="248" margin=0/>
-</p>
-
-To run the dashboard, add the following entry to your config file:
+and visualization of the policy or model, and other useful debugging features. To run the dashboard, add the following to your config file:
 
 ```ini
 ...
@@ -234,14 +229,12 @@ dashboard=True
 ...
 ```
 
-More documentation about this and other new features will be coming soon.
-
 ## Tuning the Planner
 
 A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:
 
 ```shell
-jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
+jaxplan tune <domain> <instance> <method> --trials <trials> --iters <iters> --workers <workers> --dashboard <dashboard> --filepath <filepath>
 ```
 
 where:
@@ -251,7 +244,8 @@ where:
 - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
 - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
 - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``
-- ``dashboard`` is whether the optimizations are tracked in the dashboard application
+- ``dashboard`` is whether the optimizations are tracked in the dashboard application
+- ``filepath`` is the optional file path where a config file with the best hyper-parameter setting will be saved.
 
 It is easy to tune a custom range of the planner's hyper-parameters efficiently.
 First create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
@@ -291,23 +285,16 @@ env = pyRDDLGym.make(domain, instance, vectorized=True)
 with open('path/to/config.cfg', 'r') as file:
     config_template = file.read()
 
-#
+# tune weight from 10^-1 ... 10^5 and lr from 10^-5 ... 10^1
 def power_10(x):
-    return 10.0 ** x
-
-
-    Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),  # tune weight from 10^-1 ... 10^5
-    Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10),  # tune lr from 10^-5 ... 10^1
-]
+    return 10.0 ** x
+hyperparams = [Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),
+               Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10)]
 
 # build the tuner and tune
 tuning = JaxParameterTuning(env=env,
-                            config_template=config_template,
-
-                            online=False,
-                            eval_trials=trials,
-                            num_workers=workers,
-                            gp_iters=iters)
+                            config_template=config_template, hyperparams=hyperparams,
+                            online=False, eval_trials=trials, num_workers=workers, gp_iters=iters)
 tuning.tune(key=42, log_file='path/to/log.csv')
 ```
 
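The tuning and planning pieces referenced in the README fragments above fit together roughly as sketched below: `load_config_from_string` parses the best config found by tuning (as in run_tune.py earlier in this diff), and the resulting argument dicts drive the planner and controller. The `JaxBackpropPlanner(rddl=env.model, **planner_args)` construction and the `evaluate` call follow the project's documented usage and are assumptions, not lines from this diff.

```python
# Hedged sketch: wire the tuned config into an offline controller.
from pyRDDLGym_jax.core.planner import (
    JaxBackpropPlanner, JaxOfflineController, load_config_from_string
)

planner_args, _, train_args = load_config_from_string(tuning.best_config)
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)  # assumed pattern
controller = JaxOfflineController(planner, **train_args)
controller.evaluate(env, episodes=1, verbose=True)            # assumed API
```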
{pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.6.dist-info}/RECORD
CHANGED
@@ -1,20 +1,21 @@
-pyRDDLGym_jax/__init__.py,sha256=
-pyRDDLGym_jax/entry_point.py,sha256=
+pyRDDLGym_jax/__init__.py,sha256=VUmQViJtwUg1JGcgXlmNm0fE3Njyruyt_76c16R-LTo,19
+pyRDDLGym_jax/entry_point.py,sha256=K0zy1oe66jfBHkHHCM6aGHbbiVqnQvDhDb8se4uaKHE,3319
 pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pyRDDLGym_jax/core/compiler.py,sha256=
-pyRDDLGym_jax/core/logic.py,sha256=
-pyRDDLGym_jax/core/
-pyRDDLGym_jax/core/
-pyRDDLGym_jax/core/
+pyRDDLGym_jax/core/compiler.py,sha256=Bpgfw4nqRFqiTju7ioR0B0Dhp3wMvk-9LmTRpMmLIOc,83457
+pyRDDLGym_jax/core/logic.py,sha256=9rRpKJCx4Us_2c6BiSWRN9k2sM_iYsAK1B7zcgwu3ZA,56290
+pyRDDLGym_jax/core/model.py,sha256=4WfmtUVN1EKCD-7eWeQByWk8_zKyDcMABAMdlxN1LOU,27215
+pyRDDLGym_jax/core/planner.py,sha256=a684ss5TAkJ-P2SEbZA90FSpDwFxHwRoaLtbRIBspAA,146450
+pyRDDLGym_jax/core/simulator.py,sha256=ayCATTUL3clLaZPQ5OUg2bI_c26KKCTq6TbrxbMsVdc,10470
+pyRDDLGym_jax/core/tuning.py,sha256=BWcQZk02TMLexTz1Sw4lX2EQKvmPbp7biC51M-IiNUw,25153
 pyRDDLGym_jax/core/visualization.py,sha256=4BghMp8N7qtF0tdyDSqtxAxNfP9HPrQWTiXzAMJmx7o,70365
 pyRDDLGym_jax/core/assets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pyRDDLGym_jax/core/assets/favicon.ico,sha256=RMMrI9YvmF81TgYG7FO7UAre6WmYFkV3B2GmbA1l0kM,175085
 pyRDDLGym_jax/examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pyRDDLGym_jax/examples/run_gradient.py,sha256=KhXvijRDZ4V7N8NOI2WV8ePGpPna5_vnET61YwS7Tco,2919
 pyRDDLGym_jax/examples/run_gym.py,sha256=rXvNWkxe4jHllvbvU_EOMji_2-2k5d4tbBKhpMm_Gaw,1526
-pyRDDLGym_jax/examples/run_plan.py,sha256=
+pyRDDLGym_jax/examples/run_plan.py,sha256=4y7JHqTxY5O1ltP6N7rar0jMiw7u9w1nuAIOcmDaAuE,2806
 pyRDDLGym_jax/examples/run_scipy.py,sha256=7uVnDXb7D3NTJqA2L8nrcYDJP-k0ba9dl9YqA2CD9ac,2301
-pyRDDLGym_jax/examples/run_tune.py,sha256=
+pyRDDLGym_jax/examples/run_tune.py,sha256=F5KWgtoCPbf7XHB6HW9LjxarD57U2LvuGdTz67OL1DY,4114
 pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg,sha256=mE8MqhOlkHeXIGEVrnR3QY6I-_iy4uxFYRA71P1bmtk,347
 pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg,sha256=nFFYHCKQUMn8x-OpJwu2pwe1tycNSJ8iAIwSkCBn33E,370
 pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg,sha256=eJ3HvHjODoKdtX7u-AM51xQaHJnYgzEy2t3omNG2oCs,340
@@ -38,12 +39,12 @@ pyRDDLGym_jax/examples/configs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5
 pyRDDLGym_jax/examples/configs/default_drp.cfg,sha256=XeMWAAG_OFZo7JAMxS5-XXroZaeVMzfM0NswmEobIns,373
 pyRDDLGym_jax/examples/configs/default_replan.cfg,sha256=CK4cEz8ReXyAZPLaLG9clIIRXAqM3IplUCxbLt_V2lY,407
 pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=mJo0woDevhQCSQfJg30ULVy9qGIJDIw73XCe6pyIPtg,369
-pyRDDLGym_jax/examples/configs/tuning_drp.cfg,sha256=
-pyRDDLGym_jax/examples/configs/tuning_replan.cfg,sha256=
-pyRDDLGym_jax/examples/configs/tuning_slp.cfg,sha256=
-pyrddlgym_jax-2.
-pyrddlgym_jax-2.
-pyrddlgym_jax-2.
-pyrddlgym_jax-2.
-pyrddlgym_jax-2.
-pyrddlgym_jax-2.
+pyRDDLGym_jax/examples/configs/tuning_drp.cfg,sha256=zocZn_cVarH5i0hOlt2Zu0NwmXYBmTTghLaXLtQOGto,526
+pyRDDLGym_jax/examples/configs/tuning_replan.cfg,sha256=9oIhtw9cuikmlbDgCgbrTc5G7hUio-HeAv_3CEGVclY,523
+pyRDDLGym_jax/examples/configs/tuning_slp.cfg,sha256=QqnyR__5-HhKeCDfGDel8VIlqsjxRHk4SSH089zJP8s,486
+pyrddlgym_jax-2.6.dist-info/licenses/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
+pyrddlgym_jax-2.6.dist-info/METADATA,sha256=1gY3EPRHKMVeZYYgq4DCqWvw3Q1Ak5XVYRaIO2UlQXc,16770
+pyrddlgym_jax-2.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+pyrddlgym_jax-2.6.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
+pyrddlgym_jax-2.6.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
+pyrddlgym_jax-2.6.dist-info/RECORD,,
{pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.6.dist-info}/entry_points.txt
File without changes
{pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.6.dist-info/licenses}/LICENSE
File without changes
{pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.6.dist-info}/top_level.txt
File without changes