pyRDDLGym-jax 0.4-py3-none-any.whl → 1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +463 -592
- pyRDDLGym_jax/core/logic.py +832 -530
- pyRDDLGym_jax/core/planner.py +422 -474
- pyRDDLGym_jax/core/simulator.py +7 -5
- pyRDDLGym_jax/core/tuning.py +390 -584
- pyRDDLGym_jax/core/visualization.py +1463 -0
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +5 -5
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +5 -4
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +5 -4
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +7 -6
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_replan.cfg +5 -4
- pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
- pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
- pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
- pyRDDLGym_jax/examples/run_plan.py +4 -1
- pyRDDLGym_jax/examples/run_tune.py +40 -29
- {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-1.0.dist-info}/METADATA +164 -105
- pyRDDLGym_jax-1.0.dist-info/RECORD +45 -0
- {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-1.0.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
- pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
- pyRDDLGym_jax-0.4.dist-info/RECORD +0 -44
- {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-1.0.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-1.0.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/tuning.py
CHANGED
@@ -1,82 +1,94 @@
-from copy import deepcopy
 import csv
 import datetime
-
+import threading
+import multiprocessing
 import os
 import time
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple
 import warnings
 warnings.filterwarnings("ignore")

+from sklearn.gaussian_process.kernels import Matern, ConstantKernel
 from bayes_opt import BayesianOptimization
-from bayes_opt.
+from bayes_opt.acquisition import AcquisitionFunction, UpperConfidenceBound
 import jax
 import numpy as np

-from pyRDDLGym.core.debug.exception import raise_warning
 from pyRDDLGym.core.env import RDDLEnv

 from pyRDDLGym_jax.core.planner import (
     JaxBackpropPlanner,
-    JaxStraightLinePlan,
-    JaxDeepReactivePolicy,
     JaxOfflineController,
-    JaxOnlineController
+    JaxOnlineController,
+    load_config_from_string
 )

+# try to load the dash board
+try:
+    from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard
+except Exception:
+    raise_warning('Failed to load the dashboard visualization tool: '
+                  'please make sure you have installed the required packages.',
+                  'red')
+    traceback.print_exc()
+    JaxPlannerDashboard = None
+
+
+class Hyperparameter:
+    '''A generic hyper-parameter of the planner that can be tuned.'''
+
+    def __init__(self, tag: str, lower_bound: float, upper_bound: float,
+                 search_to_config_map: Callable) -> None:
+        self.tag = tag
+        self.lower_bound = lower_bound
+        self.upper_bound = upper_bound
+        self.search_to_config_map = search_to_config_map
+
+    def __str__(self) -> str:
+        return (f'{self.search_to_config_map.__name__} '
+                f': [{self.lower_bound}, {self.upper_bound}] -> {self.tag}')
+
+
 Kwargs = Dict[str, Any]
+ParameterValues = Dict[str, Any]
+Hyperparameters = Iterable[Hyperparameter]
+
+COLUMNS = ['pid', 'worker', 'iteration', 'target', 'best_target', 'acq_params']


-# ===============================================================================
-#
-# GENERIC TUNING MODULE
-#
-# Currently contains three implementations:
-# 1. straight line plan
-# 2. re-planning
-# 3. deep reactive policies
-#
-# ===============================================================================
 class JaxParameterTuning:
     '''A general-purpose class for tuning a Jax planner.'''

     def __init__(self, env: RDDLEnv,
-
-
-
-                 timeout_tuning: float=np.inf,
+                 config_template: str,
+                 hyperparams: Hyperparameters,
+                 online: bool,
                  eval_trials: int=5,
                  verbose: bool=True,
-
-                 plan_kwargs: Optional[Kwargs]=None,
+                 timeout_tuning: float=np.inf,
                  pool_context: str='spawn',
                  num_workers: int=1,
                  poll_frequency: float=0.2,
                  gp_iters: int=25,
-                 acquisition: Optional[
+                 acquisition: Optional[AcquisitionFunction]=None,
                  gp_init_kwargs: Optional[Kwargs]=None,
                  gp_params: Optional[Kwargs]=None) -> None:
         '''Creates a new instance for tuning hyper-parameters for Jax planners
         on the given RDDL domain and instance.

         :param env: the RDDLEnv describing the MDP to optimize
-        :param
-
-
-
-
-
-        :param timeout_training: the maximum amount of time to spend training per
-        trial/decision step (in seconds)
+        :param config_template: base configuration file content to tune: regex
+        matches are specified directly in the config and map to keys in the
+        hyperparams_dict field
+        :param hyperparams: list of hyper-parameters to regex replace in the
+        config template during tuning
+        :param online: whether the planner is optimized online or offline
         :param timeout_tuning: the maximum amount of time to spend tuning
         hyperparameters in general (in seconds)
         :param eval_trials: how many trials to perform independent training
         in order to estimate the return for each set of hyper-parameters
         :param verbose: whether to print intermediate results of tuning
-        :param
-        :param plan_kwargs: additional arguments to feed to the plan/policy
-        :param pool_context: context for multiprocessing pool (defaults to
-        "spawn")
+        :param pool_context: context for multiprocessing pool (default "spawn")
         :param num_workers: how many points to evaluate in parallel
         :param poll_frequency: how often (in seconds) to poll for completed
         jobs, necessary if num_workers > 1
@@ -86,21 +98,20 @@ class JaxParameterTuning:
         during initialization
         :param gp_params: additional parameters to feed to Bayesian optimizer
         after initialization optimization
-        '''
-
+        '''
+        # objective parameters
         self.env = env
+        self.config_template = config_template
+        hyperparams_dict = {hyper_param.tag: hyper_param
+                            for hyper_param in hyperparams
+                            if hyper_param.tag in config_template}
         self.hyperparams_dict = hyperparams_dict
-        self.
-        self.timeout_training = timeout_training
-        self.timeout_tuning = timeout_tuning
+        self.online = online
         self.eval_trials = eval_trials
         self.verbose = verbose
-
-
-        self.
-        if plan_kwargs is None:
-            plan_kwargs = {}
-        self.plan_kwargs = plan_kwargs
+
+        # Bayesian parameters
+        self.timeout_tuning = timeout_tuning
         self.pool_context = pool_context
         self.num_workers = num_workers
         self.poll_frequency = poll_frequency
@@ -109,21 +120,33 @@ class JaxParameterTuning:
             gp_init_kwargs = {}
         self.gp_init_kwargs = gp_init_kwargs
         if gp_params is None:
-            gp_params = {'n_restarts_optimizer':
+            gp_params = {'n_restarts_optimizer': 25,
+                         'kernel': self.make_default_kernel()}
         self.gp_params = gp_params
-
-        # create acquisition function
-        self.acq_args = None
         if acquisition is None:
             num_samples = self.gp_iters * self.num_workers
-            acquisition
+            acquisition = JaxParameterTuning.annealing_acquisition(num_samples)
         self.acquisition = acquisition

+    @staticmethod
+    def make_default_kernel():
+        weight1 = ConstantKernel(1.0, (0.01, 100.0))
+        weight2 = ConstantKernel(1.0, (0.01, 100.0))
+        weight3 = ConstantKernel(1.0, (0.01, 100.0))
+        kernel1 = Matern(length_scale=0.5, length_scale_bounds=(0.1, 0.5), nu=2.5)
+        kernel2 = Matern(length_scale=1.0, length_scale_bounds=(0.5, 1.0), nu=2.5)
+        kernel3 = Matern(length_scale=5.0, length_scale_bounds=(1.0, 5.0), nu=2.5)
+        return weight1 * kernel1 + weight2 * kernel2 + weight3 * kernel3
+
     def summarize_hyperparameters(self) -> None:
+        hyper_params_table = []
+        for (_, param) in self.hyperparams_dict.items():
+            hyper_params_table.append(f'    {str(param)}')
+        hyper_params_table = '\n'.join(hyper_params_table)
         print(f'hyperparameter optimizer parameters:\n'
-              f'    tuned_hyper_parameters
+              f'    tuned_hyper_parameters =\n{hyper_params_table}\n'
               f'    initialization_args ={self.gp_init_kwargs}\n'
-              f'
+              f'    gp_params ={self.gp_params}\n'
               f'    tuning_iterations ={self.gp_iters}\n'
               f'    tuning_timeout ={self.timeout_tuning}\n'
               f'    tuning_batch_size ={self.num_workers}\n'
@@ -131,154 +154,348 @@ class JaxParameterTuning:
               f'    mp_pool_poll_frequency ={self.poll_frequency}\n'
               f'meta-objective parameters:\n'
               f'    planning_trials_per_iter ={self.eval_trials}\n'
-              f'
-              f'    planning_timeout_per_trial={self.timeout_training}\n'
-              f'    acquisition_fn ={type(self.acquisition).__name__}')
-        if self.acq_args is not None:
-            print(f'using default acquisition function:\n'
-                  f'    utility_kind ={self.acq_args[0]}\n'
-                  f'    initial_kappa={self.acq_args[1]}\n'
-                  f'    kappa_decay ={self.acq_args[2]}')
+              f'    acquisition_fn ={self.acquisition}')

     @staticmethod
-    def
-
-
-            kind='ucb',
+    def annealing_acquisition(n_samples: int, n_delay_samples: int=0,
+                              kappa1: float=10.0, kappa2: float=1.0) -> UpperConfidenceBound:
+        acq_fn = UpperConfidenceBound(
             kappa=kappa1,
-
-
-
-        return
+            exploration_decay=(kappa2 / kappa1) ** (1.0 / (n_samples - n_delay_samples)),
+            exploration_decay_delay=n_delay_samples
+        )
+        return acq_fn
+
+    @staticmethod
+    def search_to_config_params(hyper_params: Hyperparameters,
+                                params: ParameterValues) -> ParameterValues:
+        config_params = {
+            tag: param.search_to_config_map(params[tag])
+            for (tag, param) in hyper_params.items()
+        }
+        return config_params
+
+    @staticmethod
+    def config_from_template(config_template: str,
+                             config_params: ParameterValues) -> str:
+        config_string = config_template
+        for (tag, param_value) in config_params.items():
+            config_string = config_string.replace(tag, str(param_value))
+        return config_string

-
-
+    @property
+    def best_config(self) -> str:
+        return self.config_from_template(self.config_template, self.best_params)
+
+    @staticmethod
+    def queue_listener(queue, dashboard):
+        while True:
+            args = queue.get()
+            if args is None:
+                break
+            elif len(args) == 2:
+                dashboard.update_experiment(*args)
+            else:
+                dashboard.register_experiment(*args)

     @staticmethod
-    def
-
+    def offline_trials(env, planner, train_args, key, iteration, index, num_trials,
+                       verbose, viz, queue):
+        average_reward = 0.0
+        for trial in range(num_trials):
+            key, subkey = jax.random.split(key)
+            experiment_id = f'iter={iteration}, worker={index}, trial={trial}'
+            if queue is not None:
+                queue.put((
+                    experiment_id,
+                    JaxPlannerDashboard.get_planner_info(planner),
+                    subkey[0],
+                    viz
+                ))
+
+            # train the policy
+            callback = None
+            for callback in planner.optimize_generator(key=subkey, **train_args):
+                if queue is not None and queue.empty():
+                    queue.put((experiment_id, callback))
+            best_params = None if callback is None else callback['best_params']
+
+            # evaluate the policy in the real environment
+            policy = JaxOfflineController(
+                planner=planner, key=subkey, tqdm_position=index,
+                params=best_params, train_on_reset=False)
+            total_reward = policy.evaluate(env, seed=np.array(subkey)[0])['mean']
+
+            # update average reward
+            if verbose:
+                iters = None if callback is None else callback['iteration']
+                print(f'    [{index}] trial {trial + 1}, key={subkey[0]}, '
+                      f'reward={total_reward:.6f}, iters={iters}', flush=True)
+            average_reward += total_reward / num_trials
+
+        if verbose:
+            print(f'[{index}] average reward={average_reward:.6f}', flush=True)
+        return average_reward
+
+    @staticmethod
+    def online_trials(env, planner, train_args, key, iteration, index, num_trials,
+                      verbose, viz, queue):
+        average_reward = 0.0
+        for trial in range(num_trials):
+            key, subkey = jax.random.split(key)
+            experiment_id = f'iter={iteration}, worker={index}, trial={trial}'
+            if queue is not None:
+                queue.put((
+                    experiment_id,
+                    JaxPlannerDashboard.get_planner_info(planner),
+                    subkey[0],
+                    viz
+                ))
+
+            # initialize the online policy
+            policy = JaxOnlineController(
+                planner=planner, key=subkey, tqdm_position=index, **train_args)
+
+            # evaluate the policy in the real environment
+            total_reward = 0.0
+            callback = None
+            state, _ = env.reset(seed=np.array(subkey)[0])
+            elapsed_time = 0.0
+            for step in range(env.horizon):
+                action = policy.sample_action(state)
+                next_state, reward, terminated, truncated, _ = env.step(action)
+                total_reward += reward
+                done = terminated or truncated
+                state = next_state
+                callback = policy.callback
+                elapsed_time += callback['elapsed_time']
+                callback['iteration'] = step
+                callback['progress'] = int(100 * (step + 1.) / env.horizon)
+                callback['elapsed_time'] = elapsed_time
+                if queue is not None and queue.empty():
+                    queue.put((experiment_id, callback))
+                if done:
+                    break
+
+            # update average reward
+            if verbose:
+                iters = None if callback is None else callback['iteration']
+                print(f'    [{index}] trial {trial + 1}, key={subkey[0]}, '
+                      f'reward={total_reward:.6f}, iters={iters}', flush=True)
+            average_reward += total_reward / num_trials
+
+        if verbose:
+            print(f'[{index}] average reward={average_reward:.6f}', flush=True)
+        return average_reward
+
+    @staticmethod
+    def objective_function(params: ParameterValues,
+                           key: jax.random.PRNGKey,
+                           index: int,
+                           iteration: int,
+                           kwargs: Kwargs,
+                           queue: object) -> Tuple[ParameterValues, float, int, int]:
+        '''A pickleable objective function to evaluate a single hyper-parameter
+        configuration.'''
+
+        hyperparams_dict = kwargs['hyperparams_dict']
+        config_template = kwargs['config_template']
+        online = kwargs['online']
+        domain = kwargs['domain']
+        instance = kwargs['instance']
+        num_trials = kwargs['eval_trials']
+        viz = kwargs['viz']
+        verbose = kwargs['verbose']
+
+        # config string substitution and parsing
+        config_params = JaxParameterTuning.search_to_config_params(hyperparams_dict, params)
+        if verbose:
+            config_param_str = ', '.join(
+                f'{k}={v}' for (k, v) in config_params.items())
+            print(f'[{index}] key={key[0]}, {config_param_str}', flush=True)
+        config_string = JaxParameterTuning.config_from_template(config_template, config_params)
+        planner_args, _, train_args = load_config_from_string(config_string)
+
+        # remove keywords that should not be in the tuner
+        train_args.pop('dashboard', None)
+
+        # initialize env for evaluation (need fresh copy to avoid concurrency)
+        env = RDDLEnv(domain, instance, vectorized=True, enforce_action_constraints=False)
+
+        # run planning algorithm
+        planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
+        if online:
+            average_reward = JaxParameterTuning.online_trials(
+                env, planner, train_args, key, iteration, index, num_trials,
+                verbose, viz, queue
+            )
+        else:
+            average_reward = JaxParameterTuning.offline_trials(
+                env, planner, train_args, key, iteration, index,
+                num_trials, verbose, viz, queue
+            )
+
         pid = os.getpid()
-        return
-
-    def
-
-
+        return params, average_reward, index, pid
+
+    def tune_optimizer(self, optimizer: BayesianOptimization) -> None:
+        '''Tunes the Bayesian optimization algorithm hyper-parameters.'''
+        print('\n' + f'The current kernel is {repr(optimizer._gp.kernel_)}.')
+
+    def tune(self, key: int, log_file: str, show_dashboard: bool=False) -> ParameterValues:
         '''Tunes the hyper-parameters for Jax planner, returns the best found.'''
+
         self.summarize_hyperparameters()

-
+        # clear and prepare output file
+        with open(log_file, 'w', newline='') as file:
+            writer = csv.writer(file)
+            writer.writerow(COLUMNS + list(self.hyperparams_dict.keys()))

-        #
-
-
+        # create a dash-board for visualizing experiment runs
+        if show_dashboard:
+            dashboard = JaxPlannerDashboard()
+            dashboard.launch()

+        # objective function auxiliary data
+        obj_kwargs = {
+            'hyperparams_dict': self.hyperparams_dict,
+            'config_template': self.config_template,
+            'online': self.online,
+            'domain': self.env.domain_text,
+            'instance': self.env.instance_text,
+            'eval_trials': self.eval_trials,
+            'viz': self.env._visualizer,
+            'verbose': self.verbose
+        }
+
         # create optimizer
         hyperparams_bounds = {
-
-            for (
+            tag: (param.lower_bound, param.upper_bound)
+            for (tag, param) in self.hyperparams_dict.items()
         }
         optimizer = BayesianOptimization(
-            f=None,
+            f=None,
+            acquisition_function=self.acquisition,
             pbounds=hyperparams_bounds,
             allow_duplicate_points=True,  # to avoid crash
             random_state=np.random.RandomState(key),
             **self.gp_init_kwargs
         )
         optimizer.set_gp_params(**self.gp_params)
-        utility = self.acquisition

         # suggest initial parameters to evaluate
         num_workers = self.num_workers
-
+        suggested_params, acq_params = [], []
         for _ in range(num_workers):
-
-
-
-            kappas.append(utility.kappa)
+            probe = optimizer.suggest()
+            suggested_params.append(probe)
+            acq_params.append(vars(optimizer.acquisition_function))

-
-        filename = self._filename(filename, 'csv')
-        with open(filename, 'w', newline='') as file:
-            writer = csv.writer(file)
-            writer.writerow(
-                ['pid', 'worker', 'iteration', 'target', 'best_target', 'kappa'] + \
-                list(hyperparams_bounds.keys())
-            )
-
-        # start multiprocess evaluation
-        worker_ids = list(range(num_workers))
-        best_params, best_target = None, -np.inf
-
-        for it in range(self.gp_iters):
+        with multiprocessing.Manager() as manager:

-            #
-
-
-
-
+            # queue and parallel thread for handing render events
+            if show_dashboard:
+                queue = manager.Queue()
+                dashboard_thread = threading.Thread(
+                    target=JaxParameterTuning.queue_listener,
+                    args=(queue, dashboard)
+                )
+                dashboard_thread.start()
+            else:
+                queue = None

-            #
-
-
-
-
-            key, *subkeys = jax.random.split(key, num=num_workers + 1)
-            rows = [None] * num_workers
+            # start multiprocess evaluation
+            worker_ids = list(range(num_workers))
+            best_params, best_target = None, -np.inf
+            key = jax.random.PRNGKey(key)
+            start_time = time.time()

-
-            # to finish before moving to the next
-            with get_context(self.pool_context).Pool(processes=num_workers) as pool:
+            for it in range(self.gp_iters):

-                #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                # check if there is enough time left for another iteration
+                elapsed = time.time() - start_time
+                if elapsed >= self.timeout_tuning:
+                    print(f'global time limit reached at iteration {it}, aborting')
+                    break
+
+                # continue with next iteration
+                print('\n' + '*' * 80 +
+                      f'\n[{datetime.timedelta(seconds=elapsed)}] ' +
+                      f'starting iteration {it + 1}' +
+                      '\n' + '*' * 80)
+                key, *subkeys = jax.random.split(key, num=num_workers + 1)
+                rows = [None] * num_workers
+                old_best_target = best_target
+
+                # create worker pool: note each iteration must wait for all workers
+                # to finish before moving to the next
+                with multiprocessing.get_context(
+                        self.pool_context).Pool(processes=num_workers) as pool:

-                #
-
-
-
-
-
-
-
-
-
-                    old_kappa = kappas[index]
-                    kappas[index] = utility.kappa
-
-                    # transform suggestion back to natural space
-                    rddl_params = {
-                        name: pf(params[name])
-                        for (name, (*_, pf)) in self.hyperparams_dict.items()
-                    }
-
-                    # update the best suggestion so far
-                    if target > best_target:
-                        best_params, best_target = rddl_params, target
+                    # assign jobs to worker pool
+                    results = [
+                        pool.apply_async(JaxParameterTuning.objective_function,
+                                         obj_args + (it, obj_kwargs, queue))
+                        for obj_args in zip(suggested_params, subkeys, worker_ids)
+                    ]
+
+                    # wait for all workers to complete
+                    while results:
+                        time.sleep(self.poll_frequency)

-                    #
-
-
+                        # determine which jobs have completed
+                        jobs_done = []
+                        for (i, candidate) in enumerate(results):
+                            if candidate.ready():
+                                jobs_done.append(i)

-
-
-
-
+                        # get result from completed jobs
+                        for i in jobs_done[::-1]:
+
+                            # extract and register the new evaluation
+                            params, target, index, pid = results.pop(i).get()
+                            optimizer.register(params, target)
+                            optimizer._gp.fit(
+                                optimizer.space.params, optimizer.space.target)
+
+                            # update acquisition function and suggest a new point
+                            suggested_params[index] = optimizer.suggest()
+                            old_acq_params = acq_params[index]
+                            acq_params[index] = vars(optimizer.acquisition_function)
+
+                            # transform suggestion back to natural space
+                            config_params = JaxParameterTuning.search_to_config_params(
+                                self.hyperparams_dict, params)
+
+                            # update the best suggestion so far
+                            if target > best_target:
+                                best_params, best_target = config_params, target
+
+                            rows[index] = [pid, index, it, target,
+                                           best_target, old_acq_params] + \
+                                          list(config_params.values())
+
+                # print best parameter if found
+                if best_target > old_best_target:
+                    print(f'* found new best average reward {best_target:.6f}')
+
+                # tune the optimizer here
+                self.tune_optimizer(optimizer)
+
+                # write results of all processes in current iteration to file
+                with open(log_file, 'a', newline='') as file:
+                    writer = csv.writer(file)
+                    writer.writerows(rows)
+
+                # update the dashboard tuning
+                if show_dashboard:
+                    dashboard.update_tuning(optimizer, hyperparams_bounds)
+
+            # stop the queue listener thread
+            if show_dashboard:
+                queue.put(None)
+                dashboard_thread.join()

         # print summary of results
         elapsed = time.time() - start_time
@@ -286,420 +503,9 @@ class JaxParameterTuning:
               f'    time_elapsed ={datetime.timedelta(seconds=elapsed)}\n'
               f'    iterations ={it + 1}\n'
               f'    best_hyper_parameters={best_params}\n'
-              f'    best_meta_objective ={best_target}\n')
+              f'    best_meta_objective ={best_target}\n')

-
-
+        self.best_params = best_params
+        self.optimizer = optimizer
+        self.log_file = log_file
         return best_params
-
-    def _filename(self, name, ext):
-        domain_name = ''.join(c for c in self.env.model.domain_name
-                              if c.isalnum() or c == '_')
-        instance_name = ''.join(c for c in self.env.model.instance_name
-                                if c.isalnum() or c == '_')
-        filename = f'{name}_{domain_name}_{instance_name}.{ext}'
-        return filename
-
-    def _save_plot(self, filename):
-        try:
-            import matplotlib.pyplot as plt
-            from sklearn.manifold import MDS
-        except Exception as e:
-            raise_warning(f'failed to import packages matplotlib or sklearn, '
-                          f'aborting plot of search space\n{e}', 'red')
-        else:
-            data = np.loadtxt(filename, delimiter=',', dtype=object)
-            data, target = data[1:, 3:], data[1:, 2]
-            data = data.astype(np.float64)
-            target = target.astype(np.float64)
-            target = (target - np.min(target)) / (np.max(target) - np.min(target))
-            embedding = MDS(n_components=2, normalized_stress='auto')
-            data1 = embedding.fit_transform(data)
-            sc = plt.scatter(data1[:, 0], data1[:, 1], c=target, s=4.,
-                             cmap='seismic', edgecolor='gray',
-                             linewidth=0.01, alpha=0.4)
-            plt.colorbar(sc)
-            plt.savefig(self._filename('gp_points', 'pdf'))
-            plt.clf()
-            plt.close()
-
-
-# ===============================================================================
-#
-# STRAIGHT LINE PLANNING
-#
-# ===============================================================================
-def objective_slp(params, kwargs, key, index):
-
-    # transform hyper-parameters to natural space
-    param_values = [
-        pmap(params[name])
-        for (name, (*_, pmap)) in kwargs['hyperparams_dict'].items()
-    ]
-
-    # unpack hyper-parameters
-    if kwargs['wrapped_bool_actions']:
-        std, lr, w, wa = param_values
-    else:
-        std, lr, w = param_values
-        wa = None
-    if kwargs['verbose']:
-        print(f'[{index}] key={key}, std={std}, lr={lr}, w={w}, wa={wa}...', flush=True)
-
-    # initialize planning algorithm
-    planner = JaxBackpropPlanner(
-        rddl=deepcopy(kwargs['rddl']),
-        plan=JaxStraightLinePlan(
-            initializer=jax.nn.initializers.normal(std),
-            **kwargs['plan_kwargs']),
-        optimizer_kwargs={'learning_rate': lr},
-        **kwargs['planner_kwargs'])
-    policy_hparams = {name: wa for name in kwargs['wrapped_bool_actions']}
-    model_params = {name: w for name in planner.compiled.model_params}
-
-    # initialize policy
-    key, subkey = jax.random.split(key)
-    policy = JaxOfflineController(
-        planner=planner,
-        key=subkey,
-        eval_hyperparams=policy_hparams,
-        train_on_reset=True,
-        epochs=kwargs['train_epochs'],
-        train_seconds=kwargs['timeout_training'],
-        model_params=model_params,
-        policy_hyperparams=policy_hparams,
-        print_summary=False,
-        print_progress=False,
-        tqdm_position=index)
-
-    # initialize env for evaluation (need fresh copy to avoid concurrency)
-    env = RDDLEnv(domain=kwargs['domain'],
-                  instance=kwargs['instance'],
-                  vectorized=True,
-                  enforce_action_constraints=False)
-
-    # perform training
-    average_reward = 0.0
-    for trial in range(kwargs['eval_trials']):
-        key, subkey = jax.random.split(key)
-        total_reward = policy.evaluate(env, seed=np.array(subkey)[0])['mean']
-        if kwargs['verbose']:
-            print(f'    [{index}] trial {trial + 1} key={subkey}, '
-                  f'reward={total_reward}', flush=True)
-        average_reward += total_reward / kwargs['eval_trials']
-    if kwargs['verbose']:
-        print(f'[{index}] average reward={average_reward}', flush=True)
-    return average_reward
-
-
-def power_ten(x):
-    return 10.0 ** x
-
-
-class JaxParameterTuningSLP(JaxParameterTuning):
-
-    def __init__(self, *args,
-                 hyperparams_dict: Dict[str, Tuple[float, float, Callable]]={
-                     'std': (-5., 2., power_ten),
-                     'lr': (-5., 2., power_ten),
-                     'w': (0., 5., power_ten),
-                     'wa': (0., 5., power_ten)
-                 },
-                 **kwargs) -> None:
-        '''Creates a new tuning class for straight line planners.
-
-        :param *args: arguments to pass to parent class
-        :param hyperparams_dict: same as parent class, but here must contain
-        weight initialization (std), learning rate (lr), model weight (w), and
-        action weight (wa) if wrap_sigmoid and boolean action fluents exist
-        :param **kwargs: keyword arguments to pass to parent class
-        '''
-
-        super(JaxParameterTuningSLP, self).__init__(
-            *args, hyperparams_dict=hyperparams_dict, **kwargs)
-
-        # action parameters required if wrap_sigmoid and boolean action exists
-        self.wrapped_bool_actions = []
-        if self.plan_kwargs.get('wrap_sigmoid', True):
-            for var in self.env.model.action_fluents:
-                if self.env.model.variable_ranges[var] == 'bool':
-                    self.wrapped_bool_actions.append(var)
-        if not self.wrapped_bool_actions:
-            self.hyperparams_dict.pop('wa', None)
-
-    def _pickleable_objective_with_kwargs(self):
-        objective_fn = objective_slp
-
-        # duplicate planner and plan keyword arguments must be removed
-        plan_kwargs = self.plan_kwargs.copy()
-        plan_kwargs.pop('initializer', None)
-
-        planner_kwargs = self.planner_kwargs.copy()
-        planner_kwargs.pop('rddl', None)
-        planner_kwargs.pop('plan', None)
-        planner_kwargs.pop('optimizer_kwargs', None)
-
-        kwargs = {
-            'rddl': self.env.model,
-            'domain': self.env.domain_text,
-            'instance': self.env.instance_text,
-            'hyperparams_dict': self.hyperparams_dict,
-            'timeout_training': self.timeout_training,
-            'train_epochs': self.train_epochs,
-            'planner_kwargs': planner_kwargs,
-            'plan_kwargs': plan_kwargs,
-            'verbose': self.verbose,
-            'wrapped_bool_actions': self.wrapped_bool_actions,
-            'eval_trials': self.eval_trials
-        }
-        return objective_fn, kwargs
-
-
-# ===============================================================================
-#
-# REPLANNING
-#
-# ===============================================================================
-def objective_replan(params, kwargs, key, index):
-
-    # transform hyper-parameters to natural space
-    param_values = [
-        pmap(params[name])
-        for (name, (*_, pmap)) in kwargs['hyperparams_dict'].items()
-    ]
-
-    # unpack hyper-parameters
-    if kwargs['wrapped_bool_actions']:
-        std, lr, w, wa, T = param_values
-    else:
-        std, lr, w, T = param_values
-        wa = None
-    if kwargs['verbose']:
-        print(f'[{index}] key={key}, std={std}, lr={lr}, w={w}, wa={wa}, T={T}...', flush=True)
-
-    # initialize planning algorithm
-    planner = JaxBackpropPlanner(
-        rddl=deepcopy(kwargs['rddl']),
-        plan=JaxStraightLinePlan(
-            initializer=jax.nn.initializers.normal(std),
-            **kwargs['plan_kwargs']),
-        rollout_horizon=T,
-        optimizer_kwargs={'learning_rate': lr},
-        **kwargs['planner_kwargs'])
-    policy_hparams = {name: wa for name in kwargs['wrapped_bool_actions']}
-    model_params = {name: w for name in planner.compiled.model_params}
-
-    # initialize controller
-    key, subkey = jax.random.split(key)
-    policy = JaxOnlineController(
-        planner=planner,
-        key=subkey,
-        eval_hyperparams=policy_hparams,
-        warm_start=kwargs['use_guess_last_epoch'],
-        epochs=kwargs['train_epochs'],
-        train_seconds=kwargs['timeout_training'],
-        model_params=model_params,
-        policy_hyperparams=policy_hparams,
-        print_summary=False,
-        print_progress=False,
-        tqdm_position=index)
-
-    # initialize env for evaluation (need fresh copy to avoid concurrency)
-    env = RDDLEnv(domain=kwargs['domain'],
-                  instance=kwargs['instance'],
-                  vectorized=True,
-                  enforce_action_constraints=False)
-
-    # perform training
-    average_reward = 0.0
-    for trial in range(kwargs['eval_trials']):
-        key, subkey = jax.random.split(key)
-        total_reward = policy.evaluate(env, seed=np.array(subkey)[0])['mean']
-        if kwargs['verbose']:
-            print(f'    [{index}] trial {trial + 1} key={subkey}, '
-                  f'reward={total_reward}', flush=True)
-        average_reward += total_reward / kwargs['eval_trials']
-    if kwargs['verbose']:
-        print(f'[{index}] average reward={average_reward}', flush=True)
-    return average_reward
-
-
-class JaxParameterTuningSLPReplan(JaxParameterTuningSLP):
-
-    def __init__(self,
-                 *args,
-                 hyperparams_dict: Dict[str, Tuple[float, float, Callable]]={
-                     'std': (-5., 2., power_ten),
-                     'lr': (-5., 2., power_ten),
-                     'w': (0., 5., power_ten),
-                     'wa': (0., 5., power_ten),
-                     'T': (1, None, int)
-                 },
-                 use_guess_last_epoch: bool=True,
-                 **kwargs) -> None:
-        '''Creates a new tuning class for straight line planners.
-
-        :param *args: arguments to pass to parent class
-        :param hyperparams_dict: same as parent class, but here must contain
-        weight initialization (std), learning rate (lr), model weight (w),
-        action weight (wa) if wrap_sigmoid and boolean action fluents exist, and
-        lookahead horizon (T)
-        :param use_guess_last_epoch: use the trained parameters from previous
-        decision to warm-start next decision
-        :param **kwargs: keyword arguments to pass to parent class
-        '''
-
-        super(JaxParameterTuningSLPReplan, self).__init__(
-            *args, hyperparams_dict=hyperparams_dict, **kwargs)
-
-        self.use_guess_last_epoch = use_guess_last_epoch
-
-        # set upper range of lookahead horizon to environment horizon
-        if self.hyperparams_dict['T'][1] is None:
-            self.hyperparams_dict['T'] = (1, self.env.horizon, int)
-
-    def _pickleable_objective_with_kwargs(self):
-        objective_fn = objective_replan
-
-        # duplicate planner and plan keyword arguments must be removed
-        plan_kwargs = self.plan_kwargs.copy()
-        plan_kwargs.pop('initializer', None)
-
-        planner_kwargs = self.planner_kwargs.copy()
-        planner_kwargs.pop('rddl', None)
-        planner_kwargs.pop('plan', None)
-        planner_kwargs.pop('rollout_horizon', None)
-        planner_kwargs.pop('optimizer_kwargs', None)
-
-        kwargs = {
-            'rddl': self.env.model,
-            'domain': self.env.domain_text,
-            'instance': self.env.instance_text,
-            'hyperparams_dict': self.hyperparams_dict,
-            'timeout_training': self.timeout_training,
-            'train_epochs': self.train_epochs,
-            'planner_kwargs': planner_kwargs,
-            'plan_kwargs': plan_kwargs,
-            'verbose': self.verbose,
-            'wrapped_bool_actions': self.wrapped_bool_actions,
-            'eval_trials': self.eval_trials,
-            'use_guess_last_epoch': self.use_guess_last_epoch
-        }
-        return objective_fn, kwargs
-
-
-# ===============================================================================
-#
-# DEEP REACTIVE POLICIES
-#
-# ===============================================================================
-def objective_drp(params, kwargs, key, index):
-
-    # transform hyper-parameters to natural space
-    param_values = [
-        pmap(params[name])
-        for (name, (*_, pmap)) in kwargs['hyperparams_dict'].items()
-    ]
-
-    # unpack hyper-parameters
-    lr, w, layers, neurons = param_values
-    if kwargs['verbose']:
-        print(f'[{index}] key={key}, lr={lr}, w={w}, layers={layers}, neurons={neurons}...', flush=True)
-
-    # initialize planning algorithm
-    planner = JaxBackpropPlanner(
-        rddl=deepcopy(kwargs['rddl']),
-        plan=JaxDeepReactivePolicy(
-            topology=[neurons] * layers,
-            **kwargs['plan_kwargs']),
-        optimizer_kwargs={'learning_rate': lr},
-        **kwargs['planner_kwargs'])
-    policy_hparams = {name: None for name in planner._action_bounds}
-    model_params = {name: w for name in planner.compiled.model_params}
-
-    # initialize policy
-    key, subkey = jax.random.split(key)
-    policy = JaxOfflineController(
-        planner=planner,
-        key=subkey,
-        eval_hyperparams=policy_hparams,
-        train_on_reset=True,
-        epochs=kwargs['train_epochs'],
-        train_seconds=kwargs['timeout_training'],
-        model_params=model_params,
-        policy_hyperparams=policy_hparams,
-        print_summary=False,
-        print_progress=False,
-        tqdm_position=index)
-
-    # initialize env for evaluation (need fresh copy to avoid concurrency)
-    env = RDDLEnv(domain=kwargs['domain'],
-                  instance=kwargs['instance'],
-                  vectorized=True,
-                  enforce_action_constraints=False)
-
-    # perform training
-    average_reward = 0.0
-    for trial in range(kwargs['eval_trials']):
-        key, subkey = jax.random.split(key)
-        total_reward = policy.evaluate(env, seed=np.array(subkey)[0])['mean']
-        if kwargs['verbose']:
-            print(f'    [{index}] trial {trial + 1} key={subkey}, '
-                  f'reward={total_reward}', flush=True)
-        average_reward += total_reward / kwargs['eval_trials']
-    if kwargs['verbose']:
-        print(f'[{index}] average reward={average_reward}', flush=True)
-    return average_reward
-
-
-def power_two_int(x):
-    return 2 ** int(x)
-
-
-class JaxParameterTuningDRP(JaxParameterTuning):
-
-    def __init__(self, *args,
-                 hyperparams_dict: Dict[str, Tuple[float, float, Callable]]={
-                     'lr': (-7., 2., power_ten),
-                     'w': (0., 5., power_ten),
-                     'layers': (1., 3., int),
-                     'neurons': (2., 9., power_two_int)
-                 },
-                 **kwargs) -> None:
-        '''Creates a new tuning class for deep reactive policies.
-
-        :param *args: arguments to pass to parent class
-        :param hyperparams_dict: same as parent class, but here must contain
-        learning rate (lr), model weight (w), number of hidden layers (layers)
-        and number of neurons per hidden layer (neurons)
-        :param **kwargs: keyword arguments to pass to parent class
-        '''
-
-        super(JaxParameterTuningDRP, self).__init__(
-            *args, hyperparams_dict=hyperparams_dict, **kwargs)
-
-    def _pickleable_objective_with_kwargs(self):
-        objective_fn = objective_drp
-
-        # duplicate planner and plan keyword arguments must be removed
-        plan_kwargs = self.plan_kwargs.copy()
-        plan_kwargs.pop('topology', None)
-
-        planner_kwargs = self.planner_kwargs.copy()
-        planner_kwargs.pop('rddl', None)
-        planner_kwargs.pop('plan', None)
-        planner_kwargs.pop('optimizer_kwargs', None)
-
-        kwargs = {
-            'rddl': self.env.model,
-            'domain': self.env.domain_text,
-            'instance': self.env.instance_text,
-            'hyperparams_dict': self.hyperparams_dict,
-            'timeout_training': self.timeout_training,
-            'train_epochs': self.train_epochs,
-            'planner_kwargs': planner_kwargs,
-            'plan_kwargs': plan_kwargs,
-            'verbose': self.verbose,
-            'eval_trials': self.eval_trials
-        }
-        return objective_fn, kwargs