pyRDDLGym-jax 0.5 → 1.1 (py3-none-any.whl)
This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +463 -592
- pyRDDLGym_jax/core/logic.py +784 -544
- pyRDDLGym_jax/core/planner.py +336 -472
- pyRDDLGym_jax/core/simulator.py +7 -5
- pyRDDLGym_jax/core/tuning.py +392 -567
- pyRDDLGym_jax/core/visualization.py +1463 -0
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +4 -5
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
- pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
- pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
- pyRDDLGym_jax/examples/run_plan.py +4 -1
- pyRDDLGym_jax/examples/run_tune.py +40 -27
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.1.dist-info}/METADATA +167 -111
- pyRDDLGym_jax-1.1.dist-info/RECORD +45 -0
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.1.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
- pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
- pyRDDLGym_jax-0.5.dist-info/RECORD +0 -44
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.1.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.1.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/tuning.py
CHANGED
@@ -1,13 +1,15 @@
-from copy import deepcopy
 import csv
 import datetime
-
+import threading
+import multiprocessing
 import os
 import time
-
+import traceback
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple
 import warnings
 warnings.filterwarnings("ignore")
 
+from sklearn.gaussian_process.kernels import Matern, ConstantKernel
 from bayes_opt import BayesianOptimization
 from bayes_opt.acquisition import AcquisitionFunction, UpperConfidenceBound
 import jax
@@ -18,24 +20,41 @@ from pyRDDLGym.core.env import RDDLEnv
 
 from pyRDDLGym_jax.core.planner import (
     JaxBackpropPlanner,
-    JaxStraightLinePlan,
-    JaxDeepReactivePolicy,
     JaxOfflineController,
-    JaxOnlineController
+    JaxOnlineController,
+    load_config_from_string
 )
 
-
+# try to load the dash board
+try:
+    from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard
+except Exception:
+    raise_warning('Failed to load the dashboard visualization tool: '
+                  'please make sure you have installed the required packages.',
+                  'red')
+    traceback.print_exc()
+    JaxPlannerDashboard = None
+
+
+class Hyperparameter:
+    '''A generic hyper-parameter of the planner that can be tuned.'''
+
+    def __init__(self, tag: str, lower_bound: float, upper_bound: float,
+                 search_to_config_map: Callable) -> None:
+        self.tag = tag
+        self.lower_bound = lower_bound
+        self.upper_bound = upper_bound
+        self.search_to_config_map = search_to_config_map
+
+    def __str__(self) -> str:
+        return (f'{self.search_to_config_map.__name__} '
+                f': [{self.lower_bound}, {self.upper_bound}] -> {self.tag}')
 
-
-
-
-
-
-# 1. straight line plan
-# 2. re-planning
-# 3. deep reactive policies
-#
-# ===============================================================================
+
+Kwargs = Dict[str, Any]
+ParameterValues = Dict[str, Any]
+Hyperparameters = Iterable[Hyperparameter]
+
 COLUMNS = ['pid', 'worker', 'iteration', 'target', 'best_target', 'acq_params']
 
 
@@ -43,14 +62,13 @@ class JaxParameterTuning:
     '''A general-purpose class for tuning a Jax planner.'''
 
     def __init__(self, env: RDDLEnv,
-
-
-
-                 timeout_tuning: float=np.inf,
+                 config_template: str,
+                 hyperparams: Hyperparameters,
+                 online: bool,
                  eval_trials: int=5,
+                 rollouts_per_trial: int=1,
                  verbose: bool=True,
-
-                 plan_kwargs: Optional[Kwargs]=None,
+                 timeout_tuning: float=np.inf,
                  pool_context: str='spawn',
                  num_workers: int=1,
                  poll_frequency: float=0.2,
@@ -62,23 +80,20 @@ class JaxParameterTuning:
         on the given RDDL domain and instance.
 
         :param env: the RDDLEnv describing the MDP to optimize
-        :param
-
-
-
-
-
-        :param timeout_training: the maximum amount of time to spend training per
-            trial/decision step (in seconds)
+        :param config_template: base configuration file content to tune: regex
+            matches are specified directly in the config and map to keys in the
+            hyperparams_dict field
+        :param hyperparams: list of hyper-parameters to regex replace in the
+            config template during tuning
+        :param online: whether the planner is optimized online or offline
         :param timeout_tuning: the maximum amount of time to spend tuning
             hyperparameters in general (in seconds)
         :param eval_trials: how many trials to perform independent training
            in order to estimate the return for each set of hyper-parameters
+        :param rollouts_per_trial: how many rollouts to perform during evaluation
+            at the end of each training trial (only applies when online=False)
         :param verbose: whether to print intermediate results of tuning
-        :param
-        :param plan_kwargs: additional arguments to feed to the plan/policy
-        :param pool_context: context for multiprocessing pool (defaults to
-            "spawn")
+        :param pool_context: context for multiprocessing pool (default "spawn")
         :param num_workers: how many points to evaluate in parallel
         :param poll_frequency: how often (in seconds) to poll for completed
             jobs, necessary if num_workers > 1
@@ -88,21 +103,21 @@ class JaxParameterTuning:
             during initialization
         :param gp_params: additional parameters to feed to Bayesian optimizer
             after initialization optimization
-        '''
-
+        '''
+        # objective parameters
         self.env = env
+        self.config_template = config_template
+        hyperparams_dict = {hyper_param.tag: hyper_param
+                            for hyper_param in hyperparams
+                            if hyper_param.tag in config_template}
         self.hyperparams_dict = hyperparams_dict
-        self.
-        self.timeout_training = timeout_training
-        self.timeout_tuning = timeout_tuning
+        self.online = online
         self.eval_trials = eval_trials
+        self.rollouts_per_trial = rollouts_per_trial
         self.verbose = verbose
-
-
-        self.
-        if plan_kwargs is None:
-            plan_kwargs = {}
-        self.plan_kwargs = plan_kwargs
+
+        # Bayesian parameters
+        self.timeout_tuning = timeout_tuning
         self.pool_context = pool_context
         self.num_workers = num_workers
         self.poll_frequency = poll_frequency
@@ -111,20 +126,33 @@ class JaxParameterTuning:
             gp_init_kwargs = {}
         self.gp_init_kwargs = gp_init_kwargs
         if gp_params is None:
-            gp_params = {'n_restarts_optimizer':
+            gp_params = {'n_restarts_optimizer': 25,
+                         'kernel': self.make_default_kernel()}
         self.gp_params = gp_params
-
-        # create acquisition function
         if acquisition is None:
             num_samples = self.gp_iters * self.num_workers
-            acquisition = JaxParameterTuning.
+            acquisition = JaxParameterTuning.annealing_acquisition(num_samples)
         self.acquisition = acquisition
 
+    @staticmethod
+    def make_default_kernel():
+        weight1 = ConstantKernel(1.0, (0.01, 100.0))
+        weight2 = ConstantKernel(1.0, (0.01, 100.0))
+        weight3 = ConstantKernel(1.0, (0.01, 100.0))
+        kernel1 = Matern(length_scale=0.5, length_scale_bounds=(0.1, 0.5), nu=2.5)
+        kernel2 = Matern(length_scale=1.0, length_scale_bounds=(0.5, 1.0), nu=2.5)
+        kernel3 = Matern(length_scale=5.0, length_scale_bounds=(1.0, 5.0), nu=2.5)
+        return weight1 * kernel1 + weight2 * kernel2 + weight3 * kernel3
+
     def summarize_hyperparameters(self) -> None:
+        hyper_params_table = []
+        for (_, param) in self.hyperparams_dict.items():
+            hyper_params_table.append(f' {str(param)}')
+        hyper_params_table = '\n'.join(hyper_params_table)
         print(f'hyperparameter optimizer parameters:\n'
-              f' tuned_hyper_parameters
+              f' tuned_hyper_parameters =\n{hyper_params_table}\n'
               f' initialization_args ={self.gp_init_kwargs}\n'
-              f'
+              f' gp_params ={self.gp_params}\n'
               f' tuning_iterations ={self.gp_iters}\n'
               f' tuning_timeout ={self.timeout_tuning}\n'
               f' tuning_batch_size ={self.num_workers}\n'
@@ -132,43 +160,233 @@ class JaxParameterTuning:
               f' mp_pool_poll_frequency ={self.poll_frequency}\n'
               f'meta-objective parameters:\n'
               f' planning_trials_per_iter ={self.eval_trials}\n'
-              f'
-              f' planning_timeout_per_trial={self.timeout_training}\n'
+              f' rollouts_per_trial ={self.rollouts_per_trial}\n'
               f' acquisition_fn ={self.acquisition}')
 
     @staticmethod
-    def
+    def annealing_acquisition(n_samples: int, n_delay_samples: int=0,
+                              kappa1: float=10.0, kappa2: float=1.0) -> UpperConfidenceBound:
         acq_fn = UpperConfidenceBound(
             kappa=kappa1,
             exploration_decay=(kappa2 / kappa1) ** (1.0 / (n_samples - n_delay_samples)),
-            exploration_decay_delay=n_delay_samples
+            exploration_decay_delay=n_delay_samples
+        )
         return acq_fn
 
-
-
+    @staticmethod
+    def search_to_config_params(hyper_params: Hyperparameters,
+                                params: ParameterValues) -> ParameterValues:
+        config_params = {
+            tag: param.search_to_config_map(params[tag])
+            for (tag, param) in hyper_params.items()
+        }
+        return config_params
+
+    @staticmethod
+    def config_from_template(config_template: str,
+                             config_params: ParameterValues) -> str:
+        config_string = config_template
+        for (tag, param_value) in config_params.items():
+            config_string = config_string.replace(tag, str(param_value))
+        return config_string
+
+    @property
+    def best_config(self) -> str:
+        return self.config_from_template(self.config_template, self.best_params)
+
+    @staticmethod
+    def queue_listener(queue, dashboard):
+        while True:
+            args = queue.get()
+            if args is None:
+                break
+            elif len(args) == 2:
+                dashboard.update_experiment(*args)
+            else:
+                dashboard.register_experiment(*args)
 
     @staticmethod
-    def
-
+    def offline_trials(env, planner, train_args, key, iteration, index, num_trials,
+                       rollouts_per_trial, verbose, viz, queue):
+        average_reward = 0.0
+        for trial in range(num_trials):
+            key, subkey = jax.random.split(key)
+
+            # for the dashboard
+            experiment_id = f'iter={iteration}, worker={index}, trial={trial}'
+            if queue is not None and JaxPlannerDashboard is not None:
+                queue.put((
+                    experiment_id,
+                    JaxPlannerDashboard.get_planner_info(planner),
+                    subkey[0],
+                    viz
+                ))
+
+            # train the policy
+            callback = None
+            for callback in planner.optimize_generator(key=subkey, **train_args):
+                if queue is not None and queue.empty():
+                    queue.put((experiment_id, callback))
+            best_params = None if callback is None else callback['best_params']
+
+            # evaluate the policy in the real environment
+            policy = JaxOfflineController(
+                planner=planner, key=subkey, tqdm_position=index,
+                params=best_params, train_on_reset=False)
+            total_reward = policy.evaluate(env, episodes=rollouts_per_trial,
+                                           seed=np.array(subkey)[0])['mean']
+
+            # update average reward
+            if verbose:
+                iters = None if callback is None else callback['iteration']
+                print(f' [{index}] trial {trial + 1}, key={subkey[0]}, '
+                      f'reward={total_reward:.6f}, iters={iters}', flush=True)
+            average_reward += total_reward / num_trials
+
+        if verbose:
+            print(f'[{index}] average reward={average_reward:.6f}', flush=True)
+        return average_reward
+
+    @staticmethod
+    def online_trials(env, planner, train_args, key, iteration, index, num_trials,
+                      verbose, viz, queue):
+        average_reward = 0.0
+        for trial in range(num_trials):
+            key, subkey = jax.random.split(key)
+
+            # for the dashboard
+            experiment_id = f'iter={iteration}, worker={index}, trial={trial}'
+            if queue is not None and JaxPlannerDashboard is not None:
+                queue.put((
+                    experiment_id,
+                    JaxPlannerDashboard.get_planner_info(planner),
+                    subkey[0],
+                    viz
+                ))
+
+            # initialize the online policy
+            policy = JaxOnlineController(
+                planner=planner, key=subkey, tqdm_position=index, **train_args)
+
+            # evaluate the policy in the real environment
+            total_reward = 0.0
+            callback = None
+            state, _ = env.reset(seed=np.array(subkey)[0])
+            elapsed_time = 0.0
+            for step in range(env.horizon):
+                action = policy.sample_action(state)
+                next_state, reward, terminated, truncated, _ = env.step(action)
+                total_reward += reward
+                done = terminated or truncated
+                state = next_state
+                callback = policy.callback
+                elapsed_time += callback['elapsed_time']
+                callback['iteration'] = step
+                callback['progress'] = int(100 * (step + 1.) / env.horizon)
+                callback['elapsed_time'] = elapsed_time
+                if queue is not None and queue.empty():
+                    queue.put((experiment_id, callback))
+                if done:
+                    break
+
+            # update average reward
+            if verbose:
+                iters = None if callback is None else callback['iteration']
+                print(f' [{index}] trial {trial + 1}, key={subkey[0]}, '
+                      f'reward={total_reward:.6f}, iters={iters}', flush=True)
+            average_reward += total_reward / num_trials
+
+        if verbose:
+            print(f'[{index}] average reward={average_reward:.6f}', flush=True)
+        return average_reward
+
+    @staticmethod
+    def objective_function(params: ParameterValues,
+                           key: jax.random.PRNGKey,
+                           index: int,
+                           iteration: int,
+                           kwargs: Kwargs,
+                           queue: object) -> Tuple[ParameterValues, float, int, int]:
+        '''A pickleable objective function to evaluate a single hyper-parameter
+        configuration.'''
+
+        hyperparams_dict = kwargs['hyperparams_dict']
+        config_template = kwargs['config_template']
+        online = kwargs['online']
+        domain = kwargs['domain']
+        instance = kwargs['instance']
+        num_trials = kwargs['eval_trials']
+        rollouts_per_trial = kwargs['rollouts_per_trial']
+        viz = kwargs['viz']
+        verbose = kwargs['verbose']
+
+        # config string substitution and parsing
+        config_params = JaxParameterTuning.search_to_config_params(hyperparams_dict, params)
+        if verbose:
+            config_param_str = ', '.join(
+                f'{k}={v}' for (k, v) in config_params.items())
+            print(f'[{index}] key={key[0]}, {config_param_str}', flush=True)
+        config_string = JaxParameterTuning.config_from_template(config_template, config_params)
+        planner_args, _, train_args = load_config_from_string(config_string)
+
+        # remove keywords that should not be in the tuner
+        train_args.pop('dashboard', None)
+
+        # initialize env for evaluation (need fresh copy to avoid concurrency)
+        env = RDDLEnv(domain, instance, vectorized=True, enforce_action_constraints=False)
+
+        # run planning algorithm
+        planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
+        if online:
+            average_reward = JaxParameterTuning.online_trials(
+                env, planner, train_args, key, iteration, index, num_trials,
+                verbose, viz, queue
+            )
+        else:
+            average_reward = JaxParameterTuning.offline_trials(
+                env, planner, train_args, key, iteration, index,
+                num_trials, rollouts_per_trial, verbose, viz, queue
+            )
+
         pid = os.getpid()
-        return
-
-    def
-
-
+        return params, average_reward, index, pid
+
+    def tune_optimizer(self, optimizer: BayesianOptimization) -> None:
+        '''Tunes the Bayesian optimization algorithm hyper-parameters.'''
+        print('\n' + f'The current kernel is {repr(optimizer._gp.kernel_)}.')
+
+    def tune(self, key: int, log_file: str, show_dashboard: bool=False) -> ParameterValues:
        '''Tunes the hyper-parameters for Jax planner, returns the best found.'''
+
         self.summarize_hyperparameters()
 
-
+        # clear and prepare output file
+        with open(log_file, 'w', newline='') as file:
+            writer = csv.writer(file)
+            writer.writerow(COLUMNS + list(self.hyperparams_dict.keys()))
 
-        #
-
-
+        # create a dash-board for visualizing experiment runs
+        if show_dashboard and JaxPlannerDashboard is not None:
+            dashboard = JaxPlannerDashboard()
+            dashboard.launch()
 
+        # objective function auxiliary data
+        obj_kwargs = {
+            'hyperparams_dict': self.hyperparams_dict,
+            'config_template': self.config_template,
+            'online': self.online,
+            'domain': self.env.domain_text,
+            'instance': self.env.instance_text,
+            'eval_trials': self.eval_trials,
+            'rollouts_per_trial': self.rollouts_per_trial,
+            'viz': self.env._visualizer,
+            'verbose': self.verbose
+        }
+
         # create optimizer
         hyperparams_bounds = {
-
-            for (
+            tag: (param.lower_bound, param.upper_bound)
+            for (tag, param) in self.hyperparams_dict.items()
         }
         optimizer = BayesianOptimization(
             f=None,
@@ -182,91 +400,116 @@ class JaxParameterTuning:
 
         # suggest initial parameters to evaluate
         num_workers = self.num_workers
-
+        suggested_params, acq_params = [], []
         for _ in range(num_workers):
             probe = optimizer.suggest()
-
+            suggested_params.append(probe)
             acq_params.append(vars(optimizer.acquisition_function))
 
-
-        filename = self._filename(filename, 'csv')
-        with open(filename, 'w', newline='') as file:
-            writer = csv.writer(file)
-            writer.writerow(COLUMNS + list(hyperparams_bounds.keys()))
-
-        # start multiprocess evaluation
-        worker_ids = list(range(num_workers))
-        best_params, best_target = None, -np.inf
-
-        for it in range(self.gp_iters):
+        with multiprocessing.Manager() as manager:
 
-            #
-
-
-
-
+            # queue and parallel thread for handing render events
+            if show_dashboard:
+                queue = manager.Queue()
+                dashboard_thread = threading.Thread(
+                    target=JaxParameterTuning.queue_listener,
+                    args=(queue, dashboard)
+                )
+                dashboard_thread.start()
+            else:
+                queue = None
 
-            #
-
-
-
-
-            key, *subkeys = jax.random.split(key, num=num_workers + 1)
-            rows = [None] * num_workers
+            # start multiprocess evaluation
+            worker_ids = list(range(num_workers))
+            best_params, best_target = None, -np.inf
+            key = jax.random.PRNGKey(key)
+            start_time = time.time()
 
-
-            # to finish before moving to the next
-            with get_context(self.pool_context).Pool(processes=num_workers) as pool:
+            for it in range(self.gp_iters):
 
-                #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                # check if there is enough time left for another iteration
+                elapsed = time.time() - start_time
+                if elapsed >= self.timeout_tuning:
+                    print(f'global time limit reached at iteration {it}, aborting')
+                    break
+
+                # continue with next iteration
+                print('\n' + '*' * 80 +
+                      f'\n[{datetime.timedelta(seconds=elapsed)}] ' +
+                      f'starting iteration {it + 1}' +
+                      '\n' + '*' * 80)
+                key, *subkeys = jax.random.split(key, num=num_workers + 1)
+                rows = [None] * num_workers
+                old_best_target = best_target
+
+                # create worker pool: note each iteration must wait for all workers
+                # to finish before moving to the next
+                with multiprocessing.get_context(
+                    self.pool_context).Pool(processes=num_workers) as pool:
 
-                    #
-
-
-
-
-
-
-
-
-
-                    acq_params[index] = vars(optimizer.acquisition_function)
-
-                    # transform suggestion back to natural space
-                    rddl_params = {
-                        name: pf(params[name])
-                        for (name, (*_, pf)) in self.hyperparams_dict.items()
-                    }
-
-                    # update the best suggestion so far
-                    if target > best_target:
-                        best_params, best_target = rddl_params, target
+                    # assign jobs to worker pool
+                    results = [
+                        pool.apply_async(JaxParameterTuning.objective_function,
+                                         obj_args + (it, obj_kwargs, queue))
+                        for obj_args in zip(suggested_params, subkeys, worker_ids)
+                    ]
+
+                    # wait for all workers to complete
+                    while results:
+                        time.sleep(self.poll_frequency)
 
-                        #
-
-
+                        # determine which jobs have completed
+                        jobs_done = []
+                        for (i, candidate) in enumerate(results):
+                            if candidate.ready():
+                                jobs_done.append(i)
 
-
-
-
-
+                        # get result from completed jobs
+                        for i in jobs_done[::-1]:
+
+                            # extract and register the new evaluation
+                            params, target, index, pid = results.pop(i).get()
+                            optimizer.register(params, target)
+                            optimizer._gp.fit(
+                                optimizer.space.params, optimizer.space.target)
+
+                            # update acquisition function and suggest a new point
+                            suggested_params[index] = optimizer.suggest()
+                            old_acq_params = acq_params[index]
+                            acq_params[index] = vars(optimizer.acquisition_function)
+
+                            # transform suggestion back to natural space
+                            config_params = JaxParameterTuning.search_to_config_params(
+                                self.hyperparams_dict, params)
+
+                            # update the best suggestion so far
+                            if target > best_target:
+                                best_params, best_target = config_params, target
+
+                            rows[index] = [pid, index, it, target,
+                                           best_target, old_acq_params] + \
+                                list(config_params.values())
+
+                # print best parameter if found
+                if best_target > old_best_target:
+                    print(f'* found new best average reward {best_target:.6f}')
+
+                # tune the optimizer here
+                self.tune_optimizer(optimizer)
+
+                # write results of all processes in current iteration to file
+                with open(log_file, 'a', newline='') as file:
+                    writer = csv.writer(file)
+                    writer.writerows(rows)
+
+                # update the dashboard tuning
+                if show_dashboard:
+                    dashboard.update_tuning(optimizer, hyperparams_bounds)
+
+            # stop the queue listener thread
+            if show_dashboard:
+                queue.put(None)
+                dashboard_thread.join()
 
         # print summary of results
         elapsed = time.time() - start_time
@@ -274,427 +517,9 @@ class JaxParameterTuning:
               f' time_elapsed ={datetime.timedelta(seconds=elapsed)}\n'
               f' iterations ={it + 1}\n'
               f' best_hyper_parameters={best_params}\n'
-              f' best_meta_objective ={best_target}\n')
+              f' best_meta_objective ={best_target}\n')
 
-
-
+        self.best_params = best_params
+        self.optimizer = optimizer
+        self.log_file = log_file
         return best_params
-
-    def _filename(self, name, ext):
-        domain_name = ''.join(c for c in self.env.model.domain_name
-                              if c.isalnum() or c == '_')
-        instance_name = ''.join(c for c in self.env.model.instance_name
-                                if c.isalnum() or c == '_')
-        filename = f'{name}_{domain_name}_{instance_name}.{ext}'
-        return filename
-
-    def _save_plot(self, filename):
-        try:
-            import matplotlib.pyplot as plt
-            from sklearn.manifold import MDS
-        except Exception as e:
-            raise_warning(f'failed to import packages matplotlib or sklearn, '
-                          f'aborting plot of search space\n{e}', 'red')
-        else:
-            with open(filename, 'r') as file:
-                data_iter = csv.reader(file, delimiter=',')
-                data = [row for row in data_iter]
-            data = np.asarray(data, dtype=object)
-            hparam = data[1:, len(COLUMNS):].astype(np.float64)
-            target = data[1:, 3].astype(np.float64)
-            target = (target - np.min(target)) / (np.max(target) - np.min(target))
-            embedding = MDS(n_components=2, normalized_stress='auto')
-            hparam_low = embedding.fit_transform(hparam)
-            sc = plt.scatter(hparam_low[:, 0], hparam_low[:, 1], c=target, s=5,
-                             cmap='seismic', edgecolor='gray', linewidth=0)
-            ax = plt.gca()
-            for i in range(len(target)):
-                ax.annotate(str(i), (hparam_low[i, 0], hparam_low[i, 1]), fontsize=3)
-            plt.colorbar(sc)
-            plt.savefig(self._filename('gp_points', 'pdf'))
-            plt.clf()
-            plt.close()
-
-
-# ===============================================================================
-#
-# STRAIGHT LINE PLANNING
-#
-# ===============================================================================
-def objective_slp(params, kwargs, key, index):
-
-    # transform hyper-parameters to natural space
-    param_values = [
-        pmap(params[name])
-        for (name, (*_, pmap)) in kwargs['hyperparams_dict'].items()
-    ]
-
-    # unpack hyper-parameters
-    if kwargs['wrapped_bool_actions']:
-        std, lr, w, wa = param_values
-    else:
-        std, lr, w = param_values
-        wa = None
-    key, subkey = jax.random.split(key)
-    if kwargs['verbose']:
-        print(f'[{index}] key={subkey[0]}, '
-              f'std={std}, lr={lr}, w={w}, wa={wa}...', flush=True)
-
-    # initialize planning algorithm
-    planner = JaxBackpropPlanner(
-        rddl=deepcopy(kwargs['rddl']),
-        plan=JaxStraightLinePlan(
-            initializer=jax.nn.initializers.normal(std),
-            **kwargs['plan_kwargs']),
-        optimizer_kwargs={'learning_rate': lr},
-        **kwargs['planner_kwargs'])
-    policy_hparams = {name: wa for name in kwargs['wrapped_bool_actions']}
-    model_params = {name: w for name in planner.compiled.model_params}
-
-    # initialize policy
-    policy = JaxOfflineController(
-        planner=planner,
-        key=subkey,
-        eval_hyperparams=policy_hparams,
-        train_on_reset=True,
-        epochs=kwargs['train_epochs'],
-        train_seconds=kwargs['timeout_training'],
-        model_params=model_params,
-        policy_hyperparams=policy_hparams,
-        print_summary=False,
-        print_progress=False,
-        tqdm_position=index)
-
-    # initialize env for evaluation (need fresh copy to avoid concurrency)
-    env = RDDLEnv(domain=kwargs['domain'],
-                  instance=kwargs['instance'],
-                  vectorized=True,
-                  enforce_action_constraints=False)
-
-    # perform training
-    average_reward = 0.0
-    for trial in range(kwargs['eval_trials']):
-        key, subkey = jax.random.split(key)
-        total_reward = policy.evaluate(env, seed=np.array(subkey)[0])['mean']
-        if kwargs['verbose']:
-            print(f' [{index}] trial {trial + 1} key={subkey[0]}, '
-                  f'reward={total_reward}', flush=True)
-        average_reward += total_reward / kwargs['eval_trials']
-    if kwargs['verbose']:
-        print(f'[{index}] average reward={average_reward}', flush=True)
-    return average_reward
-
-
-def power_ten(x):
-    return 10.0 ** x
-
-
-class JaxParameterTuningSLP(JaxParameterTuning):
-
-    def __init__(self, *args,
-                 hyperparams_dict: Dict[str, Tuple[float, float, Callable]]={
-                     'std': (-5., 2., power_ten),
-                     'lr': (-5., 2., power_ten),
-                     'w': (0., 5., power_ten),
-                     'wa': (0., 5., power_ten)
-                 },
-                 **kwargs) -> None:
-        '''Creates a new tuning class for straight line planners.
-
-        :param *args: arguments to pass to parent class
-        :param hyperparams_dict: same as parent class, but here must contain
-            weight initialization (std), learning rate (lr), model weight (w), and
-            action weight (wa) if wrap_sigmoid and boolean action fluents exist
-        :param **kwargs: keyword arguments to pass to parent class
-        '''
-
-        super(JaxParameterTuningSLP, self).__init__(
-            *args, hyperparams_dict=hyperparams_dict, **kwargs)
-
-        # action parameters required if wrap_sigmoid and boolean action exists
-        self.wrapped_bool_actions = []
-        if self.plan_kwargs.get('wrap_sigmoid', True):
-            for var in self.env.model.action_fluents:
-                if self.env.model.variable_ranges[var] == 'bool':
-                    self.wrapped_bool_actions.append(var)
-        if not self.wrapped_bool_actions:
-            self.hyperparams_dict.pop('wa', None)
-
-    def _pickleable_objective_with_kwargs(self):
-        objective_fn = objective_slp
-
-        # duplicate planner and plan keyword arguments must be removed
-        plan_kwargs = self.plan_kwargs.copy()
-        plan_kwargs.pop('initializer', None)
-
-        planner_kwargs = self.planner_kwargs.copy()
-        planner_kwargs.pop('rddl', None)
-        planner_kwargs.pop('plan', None)
-        planner_kwargs.pop('optimizer_kwargs', None)
-
-        kwargs = {
-            'rddl': self.env.model,
-            'domain': self.env.domain_text,
-            'instance': self.env.instance_text,
-            'hyperparams_dict': self.hyperparams_dict,
-            'timeout_training': self.timeout_training,
-            'train_epochs': self.train_epochs,
-            'planner_kwargs': planner_kwargs,
-            'plan_kwargs': plan_kwargs,
-            'verbose': self.verbose,
-            'wrapped_bool_actions': self.wrapped_bool_actions,
-            'eval_trials': self.eval_trials
-        }
-        return objective_fn, kwargs
-
-
-# ===============================================================================
-#
-# REPLANNING
-#
-# ===============================================================================
-def objective_replan(params, kwargs, key, index):
-
-    # transform hyper-parameters to natural space
-    param_values = [
-        pmap(params[name])
-        for (name, (*_, pmap)) in kwargs['hyperparams_dict'].items()
-    ]
-
-    # unpack hyper-parameters
-    if kwargs['wrapped_bool_actions']:
-        std, lr, w, wa, T = param_values
-    else:
-        std, lr, w, T = param_values
-        wa = None
-    key, subkey = jax.random.split(key)
-    if kwargs['verbose']:
-        print(f'[{index}] key={subkey[0]}, '
-              f'std={std}, lr={lr}, w={w}, wa={wa}, T={T}...', flush=True)
-
-    # initialize planning algorithm
-    planner = JaxBackpropPlanner(
-        rddl=deepcopy(kwargs['rddl']),
-        plan=JaxStraightLinePlan(
-            initializer=jax.nn.initializers.normal(std),
-            **kwargs['plan_kwargs']),
-        rollout_horizon=T,
-        optimizer_kwargs={'learning_rate': lr},
-        **kwargs['planner_kwargs'])
-    policy_hparams = {name: wa for name in kwargs['wrapped_bool_actions']}
-    model_params = {name: w for name in planner.compiled.model_params}
-
-    # initialize controller
-    policy = JaxOnlineController(
-        planner=planner,
-        key=subkey,
-        eval_hyperparams=policy_hparams,
-        warm_start=kwargs['use_guess_last_epoch'],
-        epochs=kwargs['train_epochs'],
-        train_seconds=kwargs['timeout_training'],
-        model_params=model_params,
-        policy_hyperparams=policy_hparams,
-        print_summary=False,
-        print_progress=False,
-        tqdm_position=index)
-
-    # initialize env for evaluation (need fresh copy to avoid concurrency)
-    env = RDDLEnv(domain=kwargs['domain'],
-                  instance=kwargs['instance'],
-                  vectorized=True,
-                  enforce_action_constraints=False)
-
-    # perform training
-    average_reward = 0.0
-    for trial in range(kwargs['eval_trials']):
-        key, subkey = jax.random.split(key)
-        total_reward = policy.evaluate(env, seed=np.array(subkey)[0])['mean']
-        if kwargs['verbose']:
-            print(f' [{index}] trial {trial + 1} key={subkey[0]}, '
-                  f'reward={total_reward}', flush=True)
-        average_reward += total_reward / kwargs['eval_trials']
-    if kwargs['verbose']:
-        print(f'[{index}] average reward={average_reward}', flush=True)
-    return average_reward
-
-
-class JaxParameterTuningSLPReplan(JaxParameterTuningSLP):
-
-    def __init__(self,
-                 *args,
-                 hyperparams_dict: Dict[str, Tuple[float, float, Callable]]={
-                     'std': (-5., 2., power_ten),
-                     'lr': (-5., 2., power_ten),
-                     'w': (0., 5., power_ten),
-                     'wa': (0., 5., power_ten),
-                     'T': (1, None, int)
-                 },
-                 use_guess_last_epoch: bool=True,
-                 **kwargs) -> None:
-        '''Creates a new tuning class for straight line planners.
-
-        :param *args: arguments to pass to parent class
-        :param hyperparams_dict: same as parent class, but here must contain
-            weight initialization (std), learning rate (lr), model weight (w),
-            action weight (wa) if wrap_sigmoid and boolean action fluents exist, and
-            lookahead horizon (T)
-        :param use_guess_last_epoch: use the trained parameters from previous
-            decision to warm-start next decision
-        :param **kwargs: keyword arguments to pass to parent class
-        '''
-
-        super(JaxParameterTuningSLPReplan, self).__init__(
-            *args, hyperparams_dict=hyperparams_dict, **kwargs)
-
-        self.use_guess_last_epoch = use_guess_last_epoch
-
-        # set upper range of lookahead horizon to environment horizon
-        if self.hyperparams_dict['T'][1] is None:
-            self.hyperparams_dict['T'] = (1, self.env.horizon, int)
-
-    def _pickleable_objective_with_kwargs(self):
-        objective_fn = objective_replan
-
-        # duplicate planner and plan keyword arguments must be removed
-        plan_kwargs = self.plan_kwargs.copy()
-        plan_kwargs.pop('initializer', None)
-
-        planner_kwargs = self.planner_kwargs.copy()
-        planner_kwargs.pop('rddl', None)
-        planner_kwargs.pop('plan', None)
-        planner_kwargs.pop('rollout_horizon', None)
-        planner_kwargs.pop('optimizer_kwargs', None)
-
-        kwargs = {
-            'rddl': self.env.model,
-            'domain': self.env.domain_text,
-            'instance': self.env.instance_text,
-            'hyperparams_dict': self.hyperparams_dict,
-            'timeout_training': self.timeout_training,
-            'train_epochs': self.train_epochs,
-            'planner_kwargs': planner_kwargs,
-            'plan_kwargs': plan_kwargs,
-            'verbose': self.verbose,
-            'wrapped_bool_actions': self.wrapped_bool_actions,
-            'eval_trials': self.eval_trials,
-            'use_guess_last_epoch': self.use_guess_last_epoch
-        }
-        return objective_fn, kwargs
-
-
-# ===============================================================================
-#
-# DEEP REACTIVE POLICIES
-#
-# ===============================================================================
-def objective_drp(params, kwargs, key, index):
-
-    # transform hyper-parameters to natural space
-    param_values = [
-        pmap(params[name])
-        for (name, (*_, pmap)) in kwargs['hyperparams_dict'].items()
-    ]
-
-    # unpack hyper-parameters
-    lr, w, layers, neurons = param_values
-    key, subkey = jax.random.split(key)
-    if kwargs['verbose']:
-        print(f'[{index}] key={subkey[0]}, '
-              f'lr={lr}, w={w}, layers={layers}, neurons={neurons}...', flush=True)
-
-    # initialize planning algorithm
-    planner = JaxBackpropPlanner(
-        rddl=deepcopy(kwargs['rddl']),
-        plan=JaxDeepReactivePolicy(
-            topology=[neurons] * layers,
-            **kwargs['plan_kwargs']),
-        optimizer_kwargs={'learning_rate': lr},
-        **kwargs['planner_kwargs'])
-    policy_hparams = {name: None for name in planner._action_bounds}
-    model_params = {name: w for name in planner.compiled.model_params}
-
-    # initialize policy
-    policy = JaxOfflineController(
-        planner=planner,
-        key=subkey,
-        eval_hyperparams=policy_hparams,
-        train_on_reset=True,
-        epochs=kwargs['train_epochs'],
-        train_seconds=kwargs['timeout_training'],
-        model_params=model_params,
-        policy_hyperparams=policy_hparams,
-        print_summary=False,
-        print_progress=False,
-        tqdm_position=index)
-
-    # initialize env for evaluation (need fresh copy to avoid concurrency)
-    env = RDDLEnv(domain=kwargs['domain'],
-                  instance=kwargs['instance'],
-                  vectorized=True,
-                  enforce_action_constraints=False)
-
-    # perform training
-    average_reward = 0.0
-    for trial in range(kwargs['eval_trials']):
-        key, subkey = jax.random.split(key)
-        total_reward = policy.evaluate(env, seed=np.array(subkey)[0])['mean']
-        if kwargs['verbose']:
-            print(f' [{index}] trial {trial + 1} key={subkey[0]}, '
-                  f'reward={total_reward}', flush=True)
-        average_reward += total_reward / kwargs['eval_trials']
-    if kwargs['verbose']:
-        print(f'[{index}] average reward={average_reward}', flush=True)
-    return average_reward
-
-
-def power_two_int(x):
-    return 2 ** int(x)
-
-
-class JaxParameterTuningDRP(JaxParameterTuning):
-
-    def __init__(self, *args,
-                 hyperparams_dict: Dict[str, Tuple[float, float, Callable]]={
-                     'lr': (-7., 2., power_ten),
-                     'w': (0., 5., power_ten),
-                     'layers': (1., 3., int),
-                     'neurons': (2., 9., power_two_int)
-                 },
-                 **kwargs) -> None:
-        '''Creates a new tuning class for deep reactive policies.
-
-        :param *args: arguments to pass to parent class
-        :param hyperparams_dict: same as parent class, but here must contain
-            learning rate (lr), model weight (w), number of hidden layers (layers)
-            and number of neurons per hidden layer (neurons)
-        :param **kwargs: keyword arguments to pass to parent class
-        '''
-
-        super(JaxParameterTuningDRP, self).__init__(
-            *args, hyperparams_dict=hyperparams_dict, **kwargs)
-
-    def _pickleable_objective_with_kwargs(self):
-        objective_fn = objective_drp
-
-        # duplicate planner and plan keyword arguments must be removed
-        plan_kwargs = self.plan_kwargs.copy()
-        plan_kwargs.pop('topology', None)
-
-        planner_kwargs = self.planner_kwargs.copy()
-        planner_kwargs.pop('rddl', None)
-        planner_kwargs.pop('plan', None)
-        planner_kwargs.pop('optimizer_kwargs', None)
-
-        kwargs = {
-            'rddl': self.env.model,
-            'domain': self.env.domain_text,
-            'instance': self.env.instance_text,
-            'hyperparams_dict': self.hyperparams_dict,
-            'timeout_training': self.timeout_training,
-            'train_epochs': self.train_epochs,
-            'planner_kwargs': planner_kwargs,
-            'plan_kwargs': plan_kwargs,
-            'verbose': self.verbose,
-            'eval_trials': self.eval_trials
-        }
-        return objective_fn, kwargs