pyRDDLGym-jax 0.5__py3-none-any.whl → 1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. pyRDDLGym_jax/__init__.py +1 -1
  2. pyRDDLGym_jax/core/compiler.py +463 -592
  3. pyRDDLGym_jax/core/logic.py +784 -544
  4. pyRDDLGym_jax/core/planner.py +329 -463
  5. pyRDDLGym_jax/core/simulator.py +7 -5
  6. pyRDDLGym_jax/core/tuning.py +379 -568
  7. pyRDDLGym_jax/core/visualization.py +1463 -0
  8. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
  9. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +4 -5
  10. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
  11. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
  12. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
  13. pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
  14. pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
  15. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
  16. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +3 -3
  17. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
  18. pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
  19. pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
  20. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
  21. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +3 -3
  22. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
  23. pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
  24. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
  25. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +3 -3
  26. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
  27. pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
  28. pyRDDLGym_jax/examples/configs/default_replan.cfg +3 -3
  29. pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
  30. pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
  31. pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
  32. pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
  33. pyRDDLGym_jax/examples/run_plan.py +4 -1
  34. pyRDDLGym_jax/examples/run_tune.py +40 -27
  35. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/METADATA +161 -104
  36. pyRDDLGym_jax-1.0.dist-info/RECORD +45 -0
  37. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/WHEEL +1 -1
  38. pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
  39. pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
  40. pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
  41. pyRDDLGym_jax-0.5.dist-info/RECORD +0 -44
  42. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/LICENSE +0 -0
  43. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/top_level.txt +0 -0
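
The hunks below are the pyRDDLGym_jax/core/tuning.py changes (+379 -568). The 0.5 per-algorithm tuner subclasses are collapsed into a single JaxParameterTuning driven by a config-file template: each Hyperparameter names a tag that must appear verbatim in the template, and the tuner substitutes candidate values for those tags before passing the result to load_config_from_string. The following sketch illustrates how the new entry point might be called; the class and method signatures are taken from the diff, but the template keys, tag names, bounds, and file paths are illustrative assumptions (the tuning_*.cfg files added under examples/configs in this release are the reference templates).

    from pyRDDLGym.core.env import RDDLEnv
    from pyRDDLGym_jax.core.tuning import JaxParameterTuning, Hyperparameter

    # Abbreviated, hypothetical config template: LR_TAG and MODEL_WEIGHT_TAG are
    # placeholder tokens that the tuner string-replaces with candidate values
    # before parsing the config.
    CONFIG_TEMPLATE = """
    [Model]
    logic='FuzzyLogic'
    logic_kwargs={'weight': MODEL_WEIGHT_TAG}

    [Optimizer]
    method='JaxStraightLinePlan'
    method_kwargs={}
    optimizer='rmsprop'
    optimizer_kwargs={'learning_rate': LR_TAG}

    [Training]
    train_seconds=30
    """

    def power_ten(x):
        # the search space is on a log10 scale; map back to a natural value
        return 10.0 ** x

    # placeholder domain/instance paths
    env = RDDLEnv(domain='domain.rddl', instance='instance.rddl', vectorized=True)
    tuning = JaxParameterTuning(
        env=env,
        config_template=CONFIG_TEMPLATE,
        hyperparams=[Hyperparameter('LR_TAG', -5.0, 1.0, power_ten),
                     Hyperparameter('MODEL_WEIGHT_TAG', 0.0, 5.0, power_ten)],
        online=False,      # False: train offline, then roll out; True: tune a replanning controller
        eval_trials=3,
        num_workers=2)
    best_params = tuning.tune(key=42, log_file='gp_tuning.csv', show_dashboard=False)
    print(tuning.best_config)  # the template with the best found values substituted back in
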
@@ -1,41 +1,58 @@
-from copy import deepcopy
 import csv
 import datetime
-from multiprocessing import get_context
+import threading
+import multiprocessing
 import os
 import time
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple
 import warnings
 warnings.filterwarnings("ignore")
 
+from sklearn.gaussian_process.kernels import Matern, ConstantKernel
 from bayes_opt import BayesianOptimization
 from bayes_opt.acquisition import AcquisitionFunction, UpperConfidenceBound
 import jax
 import numpy as np
 
-from pyRDDLGym.core.debug.exception import raise_warning
 from pyRDDLGym.core.env import RDDLEnv
 
 from pyRDDLGym_jax.core.planner import (
     JaxBackpropPlanner,
-    JaxStraightLinePlan,
-    JaxDeepReactivePolicy,
     JaxOfflineController,
-    JaxOnlineController
+    JaxOnlineController,
+    load_config_from_string
 )
 
-Kwargs = Dict[str, Any]
+# try to load the dash board
+try:
+    from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard
+except Exception:
+    raise_warning('Failed to load the dashboard visualization tool: '
+                  'please make sure you have installed the required packages.',
+                  'red')
+    traceback.print_exc()
+    JaxPlannerDashboard = None
+
+
+class Hyperparameter:
+    '''A generic hyper-parameter of the planner that can be tuned.'''
+
+    def __init__(self, tag: str, lower_bound: float, upper_bound: float,
+                 search_to_config_map: Callable) -> None:
+        self.tag = tag
+        self.lower_bound = lower_bound
+        self.upper_bound = upper_bound
+        self.search_to_config_map = search_to_config_map
+
+    def __str__(self) -> str:
+        return (f'{self.search_to_config_map.__name__} '
+                f': [{self.lower_bound}, {self.upper_bound}] -> {self.tag}')
 
-# ===============================================================================
-#
-# GENERIC TUNING MODULE
-#
-# Currently contains three implementations:
-# 1. straight line plan
-# 2. re-planning
-# 3. deep reactive policies
-#
-# ===============================================================================
+
+Kwargs = Dict[str, Any]
+ParameterValues = Dict[str, Any]
+Hyperparameters = Iterable[Hyperparameter]
+
 COLUMNS = ['pid', 'worker', 'iteration', 'target', 'best_target', 'acq_params']
 
 
@@ -43,14 +60,12 @@ class JaxParameterTuning:
     '''A general-purpose class for tuning a Jax planner.'''
 
     def __init__(self, env: RDDLEnv,
-                 hyperparams_dict: Dict[str, Tuple[float, float, Callable]],
-                 train_epochs: int,
-                 timeout_training: float,
-                 timeout_tuning: float=np.inf,
+                 config_template: str,
+                 hyperparams: Hyperparameters,
+                 online: bool,
                  eval_trials: int=5,
                  verbose: bool=True,
-                 planner_kwargs: Optional[Kwargs]=None,
-                 plan_kwargs: Optional[Kwargs]=None,
+                 timeout_tuning: float=np.inf,
                  pool_context: str='spawn',
                  num_workers: int=1,
                  poll_frequency: float=0.2,
@@ -62,23 +77,18 @@ class JaxParameterTuning:
         on the given RDDL domain and instance.
 
         :param env: the RDDLEnv describing the MDP to optimize
-        :param hyperparams_dict: dictionary mapping name of each hyperparameter
-        to a triple, where the first two elements are lower/upper bounds on the
-        parameter value, and the last is a callable mapping the parameter to its
-        RDDL equivalent
-        :param train_epochs: the maximum number of iterations of SGD per
-        step or trial
-        :param timeout_training: the maximum amount of time to spend training per
-        trial/decision step (in seconds)
+        :param config_template: base configuration file content to tune: regex
+        matches are specified directly in the config and map to keys in the
+        hyperparams_dict field
+        :param hyperparams: list of hyper-parameters to regex replace in the
+        config template during tuning
+        :param online: whether the planner is optimized online or offline
         :param timeout_tuning: the maximum amount of time to spend tuning
         hyperparameters in general (in seconds)
         :param eval_trials: how many trials to perform independent training
         in order to estimate the return for each set of hyper-parameters
         :param verbose: whether to print intermediate results of tuning
-        :param planner_kwargs: additional arguments to feed to the planner
-        :param plan_kwargs: additional arguments to feed to the plan/policy
-        :param pool_context: context for multiprocessing pool (defaults to
-        "spawn")
+        :param pool_context: context for multiprocessing pool (default "spawn")
         :param num_workers: how many points to evaluate in parallel
         :param poll_frequency: how often (in seconds) to poll for completed
         jobs, necessary if num_workers > 1
@@ -88,21 +98,20 @@ class JaxParameterTuning:
         during initialization
         :param gp_params: additional parameters to feed to Bayesian optimizer
         after initialization optimization
-        '''
-
+        '''
+        # objective parameters
         self.env = env
+        self.config_template = config_template
+        hyperparams_dict = {hyper_param.tag: hyper_param
+                            for hyper_param in hyperparams
+                            if hyper_param.tag in config_template}
         self.hyperparams_dict = hyperparams_dict
-        self.train_epochs = train_epochs
-        self.timeout_training = timeout_training
-        self.timeout_tuning = timeout_tuning
+        self.online = online
         self.eval_trials = eval_trials
         self.verbose = verbose
-        if planner_kwargs is None:
-            planner_kwargs = {}
-        self.planner_kwargs = planner_kwargs
-        if plan_kwargs is None:
-            plan_kwargs = {}
-        self.plan_kwargs = plan_kwargs
+
+        # Bayesian parameters
+        self.timeout_tuning = timeout_tuning
         self.pool_context = pool_context
         self.num_workers = num_workers
         self.poll_frequency = poll_frequency
@@ -111,20 +120,33 @@ class JaxParameterTuning:
             gp_init_kwargs = {}
         self.gp_init_kwargs = gp_init_kwargs
         if gp_params is None:
-            gp_params = {'n_restarts_optimizer': 10}
+            gp_params = {'n_restarts_optimizer': 25,
+                         'kernel': self.make_default_kernel()}
         self.gp_params = gp_params
-
-        # create acquisition function
         if acquisition is None:
             num_samples = self.gp_iters * self.num_workers
-            acquisition = JaxParameterTuning._annealing_acquisition(num_samples)
+            acquisition = JaxParameterTuning.annealing_acquisition(num_samples)
         self.acquisition = acquisition
 
+    @staticmethod
+    def make_default_kernel():
+        weight1 = ConstantKernel(1.0, (0.01, 100.0))
+        weight2 = ConstantKernel(1.0, (0.01, 100.0))
+        weight3 = ConstantKernel(1.0, (0.01, 100.0))
+        kernel1 = Matern(length_scale=0.5, length_scale_bounds=(0.1, 0.5), nu=2.5)
+        kernel2 = Matern(length_scale=1.0, length_scale_bounds=(0.5, 1.0), nu=2.5)
+        kernel3 = Matern(length_scale=5.0, length_scale_bounds=(1.0, 5.0), nu=2.5)
+        return weight1 * kernel1 + weight2 * kernel2 + weight3 * kernel3
+
     def summarize_hyperparameters(self) -> None:
+        hyper_params_table = []
+        for (_, param) in self.hyperparams_dict.items():
+            hyper_params_table.append(f' {str(param)}')
+        hyper_params_table = '\n'.join(hyper_params_table)
         print(f'hyperparameter optimizer parameters:\n'
-              f' tuned_hyper_parameters ={self.hyperparams_dict}\n'
+              f' tuned_hyper_parameters =\n{hyper_params_table}\n'
              f' initialization_args ={self.gp_init_kwargs}\n'
-              f' additional_args ={self.gp_params}\n'
+              f' gp_params ={self.gp_params}\n'
              f' tuning_iterations ={self.gp_iters}\n'
              f' tuning_timeout ={self.timeout_tuning}\n'
              f' tuning_batch_size ={self.num_workers}\n'
@@ -132,43 +154,225 @@ class JaxParameterTuning:
              f' mp_pool_poll_frequency ={self.poll_frequency}\n'
              f'meta-objective parameters:\n'
              f' planning_trials_per_iter ={self.eval_trials}\n'
-              f' planning_iters_per_trial ={self.train_epochs}\n'
-              f' planning_timeout_per_trial={self.timeout_training}\n'
              f' acquisition_fn ={self.acquisition}')
 
     @staticmethod
-    def _annealing_acquisition(n_samples, n_delay_samples=0, kappa1=10.0, kappa2=1.0):
+    def annealing_acquisition(n_samples: int, n_delay_samples: int=0,
+                              kappa1: float=10.0, kappa2: float=1.0) -> UpperConfidenceBound:
         acq_fn = UpperConfidenceBound(
             kappa=kappa1,
             exploration_decay=(kappa2 / kappa1) ** (1.0 / (n_samples - n_delay_samples)),
-            exploration_decay_delay=n_delay_samples)
+            exploration_decay_delay=n_delay_samples
+        )
         return acq_fn
 
-    def _pickleable_objective_with_kwargs(self):
-        raise NotImplementedError
+    @staticmethod
+    def search_to_config_params(hyper_params: Hyperparameters,
+                                params: ParameterValues) -> ParameterValues:
+        config_params = {
+            tag: param.search_to_config_map(params[tag])
+            for (tag, param) in hyper_params.items()
+        }
+        return config_params
+
+    @staticmethod
+    def config_from_template(config_template: str,
+                             config_params: ParameterValues) -> str:
+        config_string = config_template
+        for (tag, param_value) in config_params.items():
+            config_string = config_string.replace(tag, str(param_value))
+        return config_string
+
+    @property
+    def best_config(self) -> str:
+        return self.config_from_template(self.config_template, self.best_params)
+
+    @staticmethod
+    def queue_listener(queue, dashboard):
+        while True:
+            args = queue.get()
+            if args is None:
+                break
+            elif len(args) == 2:
+                dashboard.update_experiment(*args)
+            else:
+                dashboard.register_experiment(*args)
+
+    @staticmethod
+    def offline_trials(env, planner, train_args, key, iteration, index, num_trials,
+                       verbose, viz, queue):
+        average_reward = 0.0
+        for trial in range(num_trials):
+            key, subkey = jax.random.split(key)
+            experiment_id = f'iter={iteration}, worker={index}, trial={trial}'
+            if queue is not None:
+                queue.put((
+                    experiment_id,
+                    JaxPlannerDashboard.get_planner_info(planner),
+                    subkey[0],
+                    viz
+                ))
+
+            # train the policy
+            callback = None
+            for callback in planner.optimize_generator(key=subkey, **train_args):
+                if queue is not None and queue.empty():
+                    queue.put((experiment_id, callback))
+            best_params = None if callback is None else callback['best_params']
+
+            # evaluate the policy in the real environment
+            policy = JaxOfflineController(
+                planner=planner, key=subkey, tqdm_position=index,
+                params=best_params, train_on_reset=False)
+            total_reward = policy.evaluate(env, seed=np.array(subkey)[0])['mean']
+
+            # update average reward
+            if verbose:
+                iters = None if callback is None else callback['iteration']
+                print(f' [{index}] trial {trial + 1}, key={subkey[0]}, '
+                      f'reward={total_reward:.6f}, iters={iters}', flush=True)
+            average_reward += total_reward / num_trials
+
+        if verbose:
+            print(f'[{index}] average reward={average_reward:.6f}', flush=True)
+        return average_reward
 
     @staticmethod
-    def _wrapped_evaluate(index, params, key, func, kwargs):
-        target = func(params=params, kwargs=kwargs, key=key, index=index)
+    def online_trials(env, planner, train_args, key, iteration, index, num_trials,
+                      verbose, viz, queue):
+        average_reward = 0.0
+        for trial in range(num_trials):
+            key, subkey = jax.random.split(key)
+            experiment_id = f'iter={iteration}, worker={index}, trial={trial}'
+            if queue is not None:
+                queue.put((
+                    experiment_id,
+                    JaxPlannerDashboard.get_planner_info(planner),
+                    subkey[0],
+                    viz
+                ))
+
+            # initialize the online policy
+            policy = JaxOnlineController(
+                planner=planner, key=subkey, tqdm_position=index, **train_args)
+
+            # evaluate the policy in the real environment
+            total_reward = 0.0
+            callback = None
+            state, _ = env.reset(seed=np.array(subkey)[0])
+            elapsed_time = 0.0
+            for step in range(env.horizon):
+                action = policy.sample_action(state)
+                next_state, reward, terminated, truncated, _ = env.step(action)
+                total_reward += reward
+                done = terminated or truncated
+                state = next_state
+                callback = policy.callback
+                elapsed_time += callback['elapsed_time']
+                callback['iteration'] = step
+                callback['progress'] = int(100 * (step + 1.) / env.horizon)
+                callback['elapsed_time'] = elapsed_time
+                if queue is not None and queue.empty():
+                    queue.put((experiment_id, callback))
+                if done:
+                    break
+
+            # update average reward
+            if verbose:
+                iters = None if callback is None else callback['iteration']
+                print(f' [{index}] trial {trial + 1}, key={subkey[0]}, '
+                      f'reward={total_reward:.6f}, iters={iters}', flush=True)
+            average_reward += total_reward / num_trials
+
+        if verbose:
+            print(f'[{index}] average reward={average_reward:.6f}', flush=True)
+        return average_reward
+
+    @staticmethod
+    def objective_function(params: ParameterValues,
+                           key: jax.random.PRNGKey,
+                           index: int,
+                           iteration: int,
+                           kwargs: Kwargs,
+                           queue: object) -> Tuple[ParameterValues, float, int, int]:
+        '''A pickleable objective function to evaluate a single hyper-parameter
+        configuration.'''
+
+        hyperparams_dict = kwargs['hyperparams_dict']
+        config_template = kwargs['config_template']
+        online = kwargs['online']
+        domain = kwargs['domain']
+        instance = kwargs['instance']
+        num_trials = kwargs['eval_trials']
+        viz = kwargs['viz']
+        verbose = kwargs['verbose']
+
+        # config string substitution and parsing
+        config_params = JaxParameterTuning.search_to_config_params(hyperparams_dict, params)
+        if verbose:
+            config_param_str = ', '.join(
+                f'{k}={v}' for (k, v) in config_params.items())
+            print(f'[{index}] key={key[0]}, {config_param_str}', flush=True)
+        config_string = JaxParameterTuning.config_from_template(config_template, config_params)
+        planner_args, _, train_args = load_config_from_string(config_string)
+
+        # remove keywords that should not be in the tuner
+        train_args.pop('dashboard', None)
+
+        # initialize env for evaluation (need fresh copy to avoid concurrency)
+        env = RDDLEnv(domain, instance, vectorized=True, enforce_action_constraints=False)
+
+        # run planning algorithm
+        planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
+        if online:
+            average_reward = JaxParameterTuning.online_trials(
+                env, planner, train_args, key, iteration, index, num_trials,
+                verbose, viz, queue
+            )
+        else:
+            average_reward = JaxParameterTuning.offline_trials(
+                env, planner, train_args, key, iteration, index,
+                num_trials, verbose, viz, queue
+            )
+
         pid = os.getpid()
-        return index, pid, params, target
-
-    def tune(self, key: jax.random.PRNGKey,
-             filename: str,
-             save_plot: bool=False) -> Dict[str, Any]:
+        return params, average_reward, index, pid
+
+    def tune_optimizer(self, optimizer: BayesianOptimization) -> None:
+        '''Tunes the Bayesian optimization algorithm hyper-parameters.'''
+        print('\n' + f'The current kernel is {repr(optimizer._gp.kernel_)}.')
+
+    def tune(self, key: int, log_file: str, show_dashboard: bool=False) -> ParameterValues:
         '''Tunes the hyper-parameters for Jax planner, returns the best found.'''
+
         self.summarize_hyperparameters()
 
-        start_time = time.time()
+        # clear and prepare output file
+        with open(log_file, 'w', newline='') as file:
+            writer = csv.writer(file)
+            writer.writerow(COLUMNS + list(self.hyperparams_dict.keys()))
 
-        # objective function
-        objective = self._pickleable_objective_with_kwargs()
-        evaluate = JaxParameterTuning._wrapped_evaluate
+        # create a dash-board for visualizing experiment runs
+        if show_dashboard:
+            dashboard = JaxPlannerDashboard()
+            dashboard.launch()
 
+        # objective function auxiliary data
+        obj_kwargs = {
+            'hyperparams_dict': self.hyperparams_dict,
+            'config_template': self.config_template,
+            'online': self.online,
+            'domain': self.env.domain_text,
+            'instance': self.env.instance_text,
+            'eval_trials': self.eval_trials,
+            'viz': self.env._visualizer,
+            'verbose': self.verbose
+        }
+
         # create optimizer
         hyperparams_bounds = {
-            name: hparam[:2]
-            for (name, hparam) in self.hyperparams_dict.items()
+            tag: (param.lower_bound, param.upper_bound)
+            for (tag, param) in self.hyperparams_dict.items()
         }
         optimizer = BayesianOptimization(
             f=None,
@@ -182,91 +386,116 @@ class JaxParameterTuning:
 
         # suggest initial parameters to evaluate
         num_workers = self.num_workers
-        suggested, acq_params = [], []
+        suggested_params, acq_params = [], []
         for _ in range(num_workers):
             probe = optimizer.suggest()
-            suggested.append(probe)
+            suggested_params.append(probe)
             acq_params.append(vars(optimizer.acquisition_function))
 
-        # clear and prepare output file
-        filename = self._filename(filename, 'csv')
-        with open(filename, 'w', newline='') as file:
-            writer = csv.writer(file)
-            writer.writerow(COLUMNS + list(hyperparams_bounds.keys()))
-
-        # start multiprocess evaluation
-        worker_ids = list(range(num_workers))
-        best_params, best_target = None, -np.inf
-
-        for it in range(self.gp_iters):
+        with multiprocessing.Manager() as manager:
 
-            # check if there is enough time left for another iteration
-            elapsed = time.time() - start_time
-            if elapsed >= self.timeout_tuning:
-                print(f'global time limit reached at iteration {it}, aborting')
-                break
+            # queue and parallel thread for handing render events
+            if show_dashboard:
+                queue = manager.Queue()
+                dashboard_thread = threading.Thread(
+                    target=JaxParameterTuning.queue_listener,
+                    args=(queue, dashboard)
+                )
+                dashboard_thread.start()
+            else:
+                queue = None
 
-            # continue with next iteration
-            print('\n' + '*' * 25 +
-                  f'\n[{datetime.timedelta(seconds=elapsed)}] ' +
-                  f'starting iteration {it + 1}' +
-                  '\n' + '*' * 25)
-            key, *subkeys = jax.random.split(key, num=num_workers + 1)
-            rows = [None] * num_workers
+            # start multiprocess evaluation
+            worker_ids = list(range(num_workers))
+            best_params, best_target = None, -np.inf
+            key = jax.random.PRNGKey(key)
+            start_time = time.time()
 
-            # create worker pool: note each iteration must wait for all workers
-            # to finish before moving to the next
-            with get_context(self.pool_context).Pool(processes=num_workers) as pool:
+            for it in range(self.gp_iters):
 
-                # assign jobs to worker pool
-                # - each trains on suggested parameters from the last iteration
-                # - this way, since each job finishes asynchronously, these
-                # parameters usually differ across jobs
-                results = [
-                    pool.apply_async(evaluate, worker_args + objective)
-                    for worker_args in zip(worker_ids, suggested, subkeys)
-                ]
-
-                # wait for all workers to complete
-                while results:
-                    time.sleep(self.poll_frequency)
-
-                    # determine which jobs have completed
-                    jobs_done = []
-                    for (i, candidate) in enumerate(results):
-                        if candidate.ready():
-                            jobs_done.append(i)
+                # check if there is enough time left for another iteration
+                elapsed = time.time() - start_time
+                if elapsed >= self.timeout_tuning:
+                    print(f'global time limit reached at iteration {it}, aborting')
+                    break
+
+                # continue with next iteration
+                print('\n' + '*' * 80 +
+                      f'\n[{datetime.timedelta(seconds=elapsed)}] ' +
+                      f'starting iteration {it + 1}' +
+                      '\n' + '*' * 80)
+                key, *subkeys = jax.random.split(key, num=num_workers + 1)
+                rows = [None] * num_workers
+                old_best_target = best_target
+
+                # create worker pool: note each iteration must wait for all workers
+                # to finish before moving to the next
+                with multiprocessing.get_context(
+                        self.pool_context).Pool(processes=num_workers) as pool:
 
-                    # get result from completed jobs
-                    for i in jobs_done[::-1]:
-
-                        # extract and register the new evaluation
-                        index, pid, params, target = results.pop(i).get()
-                        optimizer.register(params, target)
-
-                        # update acquisition function and suggest a new point
-                        suggested[index] = optimizer.suggest()
-                        old_acq_params = acq_params[index]
-                        acq_params[index] = vars(optimizer.acquisition_function)
-
-                        # transform suggestion back to natural space
-                        rddl_params = {
-                            name: pf(params[name])
-                            for (name, (*_, pf)) in self.hyperparams_dict.items()
-                        }
-
-                        # update the best suggestion so far
-                        if target > best_target:
-                            best_params, best_target = rddl_params, target
+                    # assign jobs to worker pool
+                    results = [
+                        pool.apply_async(JaxParameterTuning.objective_function,
+                                         obj_args + (it, obj_kwargs, queue))
+                        for obj_args in zip(suggested_params, subkeys, worker_ids)
+                    ]
+
+                    # wait for all workers to complete
+                    while results:
+                        time.sleep(self.poll_frequency)
 
-                        # write progress to file in real time
-                        info_i = [pid, index, it, target, best_target, old_acq_params]
-                        rows[index] = info_i + list(rddl_params.values())
+                        # determine which jobs have completed
+                        jobs_done = []
+                        for (i, candidate) in enumerate(results):
+                            if candidate.ready():
+                                jobs_done.append(i)
 
-            # write results of all processes in current iteration to file
-            with open(filename, 'a', newline='') as file:
-                writer = csv.writer(file)
-                writer.writerows(rows)
+                        # get result from completed jobs
+                        for i in jobs_done[::-1]:
+
+                            # extract and register the new evaluation
+                            params, target, index, pid = results.pop(i).get()
+                            optimizer.register(params, target)
+                            optimizer._gp.fit(
+                                optimizer.space.params, optimizer.space.target)
+
+                            # update acquisition function and suggest a new point
+                            suggested_params[index] = optimizer.suggest()
+                            old_acq_params = acq_params[index]
+                            acq_params[index] = vars(optimizer.acquisition_function)
+
+                            # transform suggestion back to natural space
+                            config_params = JaxParameterTuning.search_to_config_params(
+                                self.hyperparams_dict, params)
+
+                            # update the best suggestion so far
+                            if target > best_target:
+                                best_params, best_target = config_params, target
+
+                            rows[index] = [pid, index, it, target,
+                                           best_target, old_acq_params] + \
+                                          list(config_params.values())
+
+                # print best parameter if found
+                if best_target > old_best_target:
+                    print(f'* found new best average reward {best_target:.6f}')
+
+                # tune the optimizer here
+                self.tune_optimizer(optimizer)
+
+                # write results of all processes in current iteration to file
+                with open(log_file, 'a', newline='') as file:
+                    writer = csv.writer(file)
+                    writer.writerows(rows)
+
+                # update the dashboard tuning
+                if show_dashboard:
+                    dashboard.update_tuning(optimizer, hyperparams_bounds)
+
+            # stop the queue listener thread
+            if show_dashboard:
+                queue.put(None)
+                dashboard_thread.join()
 
         # print summary of results
         elapsed = time.time() - start_time
@@ -274,427 +503,9 @@ class JaxParameterTuning:
              f' time_elapsed ={datetime.timedelta(seconds=elapsed)}\n'
              f' iterations ={it + 1}\n'
              f' best_hyper_parameters={best_params}\n'
-              f' best_meta_objective ={best_target}\n')
+              f' best_meta_objective ={best_target}\n')
 
-        if save_plot:
-            self._save_plot(filename)
+        self.best_params = best_params
+        self.optimizer = optimizer
+        self.log_file = log_file
         return best_params
-
-    def _filename(self, name, ext):
-        domain_name = ''.join(c for c in self.env.model.domain_name
-                              if c.isalnum() or c == '_')
-        instance_name = ''.join(c for c in self.env.model.instance_name
-                                if c.isalnum() or c == '_')
-        filename = f'{name}_{domain_name}_{instance_name}.{ext}'
-        return filename
-
-    def _save_plot(self, filename):
-        try:
-            import matplotlib.pyplot as plt
-            from sklearn.manifold import MDS
-        except Exception as e:
-            raise_warning(f'failed to import packages matplotlib or sklearn, '
-                          f'aborting plot of search space\n{e}', 'red')
-        else:
-            with open(filename, 'r') as file:
-                data_iter = csv.reader(file, delimiter=',')
-                data = [row for row in data_iter]
-            data = np.asarray(data, dtype=object)
-            hparam = data[1:, len(COLUMNS):].astype(np.float64)
-            target = data[1:, 3].astype(np.float64)
-            target = (target - np.min(target)) / (np.max(target) - np.min(target))
-            embedding = MDS(n_components=2, normalized_stress='auto')
-            hparam_low = embedding.fit_transform(hparam)
-            sc = plt.scatter(hparam_low[:, 0], hparam_low[:, 1], c=target, s=5,
-                             cmap='seismic', edgecolor='gray', linewidth=0)
-            ax = plt.gca()
-            for i in range(len(target)):
-                ax.annotate(str(i), (hparam_low[i, 0], hparam_low[i, 1]), fontsize=3)
-            plt.colorbar(sc)
-            plt.savefig(self._filename('gp_points', 'pdf'))
-            plt.clf()
-            plt.close()
-
-
-# ===============================================================================
-#
-# STRAIGHT LINE PLANNING
-#
-# ===============================================================================
-def objective_slp(params, kwargs, key, index):
-
-    # transform hyper-parameters to natural space
-    param_values = [
-        pmap(params[name])
-        for (name, (*_, pmap)) in kwargs['hyperparams_dict'].items()
-    ]
-
-    # unpack hyper-parameters
-    if kwargs['wrapped_bool_actions']:
-        std, lr, w, wa = param_values
-    else:
-        std, lr, w = param_values
-        wa = None
-    key, subkey = jax.random.split(key)
-    if kwargs['verbose']:
-        print(f'[{index}] key={subkey[0]}, '
-              f'std={std}, lr={lr}, w={w}, wa={wa}...', flush=True)
-
-    # initialize planning algorithm
-    planner = JaxBackpropPlanner(
-        rddl=deepcopy(kwargs['rddl']),
-        plan=JaxStraightLinePlan(
-            initializer=jax.nn.initializers.normal(std),
-            **kwargs['plan_kwargs']),
-        optimizer_kwargs={'learning_rate': lr},
-        **kwargs['planner_kwargs'])
-    policy_hparams = {name: wa for name in kwargs['wrapped_bool_actions']}
-    model_params = {name: w for name in planner.compiled.model_params}
-
-    # initialize policy
-    policy = JaxOfflineController(
-        planner=planner,
-        key=subkey,
-        eval_hyperparams=policy_hparams,
-        train_on_reset=True,
-        epochs=kwargs['train_epochs'],
-        train_seconds=kwargs['timeout_training'],
-        model_params=model_params,
-        policy_hyperparams=policy_hparams,
-        print_summary=False,
-        print_progress=False,
-        tqdm_position=index)
-
-    # initialize env for evaluation (need fresh copy to avoid concurrency)
-    env = RDDLEnv(domain=kwargs['domain'],
-                  instance=kwargs['instance'],
-                  vectorized=True,
-                  enforce_action_constraints=False)
-
-    # perform training
-    average_reward = 0.0
-    for trial in range(kwargs['eval_trials']):
-        key, subkey = jax.random.split(key)
-        total_reward = policy.evaluate(env, seed=np.array(subkey)[0])['mean']
-        if kwargs['verbose']:
-            print(f' [{index}] trial {trial + 1} key={subkey[0]}, '
-                  f'reward={total_reward}', flush=True)
-        average_reward += total_reward / kwargs['eval_trials']
-    if kwargs['verbose']:
-        print(f'[{index}] average reward={average_reward}', flush=True)
-    return average_reward
-
-
-def power_ten(x):
-    return 10.0 ** x
-
-
-class JaxParameterTuningSLP(JaxParameterTuning):
-
-    def __init__(self, *args,
-                 hyperparams_dict: Dict[str, Tuple[float, float, Callable]]={
-                     'std': (-5., 2., power_ten),
-                     'lr': (-5., 2., power_ten),
-                     'w': (0., 5., power_ten),
-                     'wa': (0., 5., power_ten)
-                 },
-                 **kwargs) -> None:
-        '''Creates a new tuning class for straight line planners.
-
-        :param *args: arguments to pass to parent class
-        :param hyperparams_dict: same as parent class, but here must contain
-        weight initialization (std), learning rate (lr), model weight (w), and
-        action weight (wa) if wrap_sigmoid and boolean action fluents exist
-        :param **kwargs: keyword arguments to pass to parent class
-        '''
-
-        super(JaxParameterTuningSLP, self).__init__(
-            *args, hyperparams_dict=hyperparams_dict, **kwargs)
-
-        # action parameters required if wrap_sigmoid and boolean action exists
-        self.wrapped_bool_actions = []
-        if self.plan_kwargs.get('wrap_sigmoid', True):
-            for var in self.env.model.action_fluents:
-                if self.env.model.variable_ranges[var] == 'bool':
-                    self.wrapped_bool_actions.append(var)
-        if not self.wrapped_bool_actions:
-            self.hyperparams_dict.pop('wa', None)
-
-    def _pickleable_objective_with_kwargs(self):
-        objective_fn = objective_slp
-
-        # duplicate planner and plan keyword arguments must be removed
-        plan_kwargs = self.plan_kwargs.copy()
-        plan_kwargs.pop('initializer', None)
-
-        planner_kwargs = self.planner_kwargs.copy()
-        planner_kwargs.pop('rddl', None)
-        planner_kwargs.pop('plan', None)
-        planner_kwargs.pop('optimizer_kwargs', None)
-
-        kwargs = {
-            'rddl': self.env.model,
-            'domain': self.env.domain_text,
-            'instance': self.env.instance_text,
-            'hyperparams_dict': self.hyperparams_dict,
-            'timeout_training': self.timeout_training,
-            'train_epochs': self.train_epochs,
-            'planner_kwargs': planner_kwargs,
-            'plan_kwargs': plan_kwargs,
-            'verbose': self.verbose,
-            'wrapped_bool_actions': self.wrapped_bool_actions,
-            'eval_trials': self.eval_trials
-        }
-        return objective_fn, kwargs
-
-
-# ===============================================================================
-#
-# REPLANNING
-#
-# ===============================================================================
-def objective_replan(params, kwargs, key, index):
-
-    # transform hyper-parameters to natural space
-    param_values = [
-        pmap(params[name])
-        for (name, (*_, pmap)) in kwargs['hyperparams_dict'].items()
-    ]
-
-    # unpack hyper-parameters
-    if kwargs['wrapped_bool_actions']:
-        std, lr, w, wa, T = param_values
-    else:
-        std, lr, w, T = param_values
-        wa = None
-    key, subkey = jax.random.split(key)
-    if kwargs['verbose']:
-        print(f'[{index}] key={subkey[0]}, '
-              f'std={std}, lr={lr}, w={w}, wa={wa}, T={T}...', flush=True)
-
-    # initialize planning algorithm
-    planner = JaxBackpropPlanner(
-        rddl=deepcopy(kwargs['rddl']),
-        plan=JaxStraightLinePlan(
-            initializer=jax.nn.initializers.normal(std),
-            **kwargs['plan_kwargs']),
-        rollout_horizon=T,
-        optimizer_kwargs={'learning_rate': lr},
-        **kwargs['planner_kwargs'])
-    policy_hparams = {name: wa for name in kwargs['wrapped_bool_actions']}
-    model_params = {name: w for name in planner.compiled.model_params}
-
-    # initialize controller
-    policy = JaxOnlineController(
-        planner=planner,
-        key=subkey,
-        eval_hyperparams=policy_hparams,
-        warm_start=kwargs['use_guess_last_epoch'],
-        epochs=kwargs['train_epochs'],
-        train_seconds=kwargs['timeout_training'],
-        model_params=model_params,
-        policy_hyperparams=policy_hparams,
-        print_summary=False,
-        print_progress=False,
-        tqdm_position=index)
-
-    # initialize env for evaluation (need fresh copy to avoid concurrency)
-    env = RDDLEnv(domain=kwargs['domain'],
-                  instance=kwargs['instance'],
-                  vectorized=True,
-                  enforce_action_constraints=False)
-
-    # perform training
-    average_reward = 0.0
-    for trial in range(kwargs['eval_trials']):
-        key, subkey = jax.random.split(key)
-        total_reward = policy.evaluate(env, seed=np.array(subkey)[0])['mean']
-        if kwargs['verbose']:
-            print(f' [{index}] trial {trial + 1} key={subkey[0]}, '
-                  f'reward={total_reward}', flush=True)
-        average_reward += total_reward / kwargs['eval_trials']
-    if kwargs['verbose']:
-        print(f'[{index}] average reward={average_reward}', flush=True)
-    return average_reward
-
-
-class JaxParameterTuningSLPReplan(JaxParameterTuningSLP):
-
-    def __init__(self,
-                 *args,
-                 hyperparams_dict: Dict[str, Tuple[float, float, Callable]]={
-                     'std': (-5., 2., power_ten),
-                     'lr': (-5., 2., power_ten),
-                     'w': (0., 5., power_ten),
-                     'wa': (0., 5., power_ten),
-                     'T': (1, None, int)
-                 },
-                 use_guess_last_epoch: bool=True,
-                 **kwargs) -> None:
-        '''Creates a new tuning class for straight line planners.
-
-        :param *args: arguments to pass to parent class
-        :param hyperparams_dict: same as parent class, but here must contain
-        weight initialization (std), learning rate (lr), model weight (w),
-        action weight (wa) if wrap_sigmoid and boolean action fluents exist, and
-        lookahead horizon (T)
-        :param use_guess_last_epoch: use the trained parameters from previous
-        decision to warm-start next decision
-        :param **kwargs: keyword arguments to pass to parent class
-        '''
-
-        super(JaxParameterTuningSLPReplan, self).__init__(
-            *args, hyperparams_dict=hyperparams_dict, **kwargs)
-
-        self.use_guess_last_epoch = use_guess_last_epoch
-
-        # set upper range of lookahead horizon to environment horizon
-        if self.hyperparams_dict['T'][1] is None:
-            self.hyperparams_dict['T'] = (1, self.env.horizon, int)
-
-    def _pickleable_objective_with_kwargs(self):
-        objective_fn = objective_replan
-
-        # duplicate planner and plan keyword arguments must be removed
-        plan_kwargs = self.plan_kwargs.copy()
-        plan_kwargs.pop('initializer', None)
-
-        planner_kwargs = self.planner_kwargs.copy()
-        planner_kwargs.pop('rddl', None)
-        planner_kwargs.pop('plan', None)
-        planner_kwargs.pop('rollout_horizon', None)
-        planner_kwargs.pop('optimizer_kwargs', None)
-
-        kwargs = {
-            'rddl': self.env.model,
-            'domain': self.env.domain_text,
-            'instance': self.env.instance_text,
-            'hyperparams_dict': self.hyperparams_dict,
-            'timeout_training': self.timeout_training,
-            'train_epochs': self.train_epochs,
-            'planner_kwargs': planner_kwargs,
-            'plan_kwargs': plan_kwargs,
-            'verbose': self.verbose,
-            'wrapped_bool_actions': self.wrapped_bool_actions,
-            'eval_trials': self.eval_trials,
-            'use_guess_last_epoch': self.use_guess_last_epoch
-        }
-        return objective_fn, kwargs
-
-
-# ===============================================================================
-#
-# DEEP REACTIVE POLICIES
-#
-# ===============================================================================
-def objective_drp(params, kwargs, key, index):
-
-    # transform hyper-parameters to natural space
-    param_values = [
-        pmap(params[name])
-        for (name, (*_, pmap)) in kwargs['hyperparams_dict'].items()
-    ]
-
-    # unpack hyper-parameters
-    lr, w, layers, neurons = param_values
-    key, subkey = jax.random.split(key)
-    if kwargs['verbose']:
-        print(f'[{index}] key={subkey[0]}, '
-              f'lr={lr}, w={w}, layers={layers}, neurons={neurons}...', flush=True)
-
-    # initialize planning algorithm
-    planner = JaxBackpropPlanner(
-        rddl=deepcopy(kwargs['rddl']),
-        plan=JaxDeepReactivePolicy(
-            topology=[neurons] * layers,
-            **kwargs['plan_kwargs']),
-        optimizer_kwargs={'learning_rate': lr},
-        **kwargs['planner_kwargs'])
-    policy_hparams = {name: None for name in planner._action_bounds}
-    model_params = {name: w for name in planner.compiled.model_params}
-
-    # initialize policy
-    policy = JaxOfflineController(
-        planner=planner,
-        key=subkey,
-        eval_hyperparams=policy_hparams,
-        train_on_reset=True,
-        epochs=kwargs['train_epochs'],
-        train_seconds=kwargs['timeout_training'],
-        model_params=model_params,
-        policy_hyperparams=policy_hparams,
-        print_summary=False,
-        print_progress=False,
-        tqdm_position=index)
-
-    # initialize env for evaluation (need fresh copy to avoid concurrency)
-    env = RDDLEnv(domain=kwargs['domain'],
-                  instance=kwargs['instance'],
-                  vectorized=True,
-                  enforce_action_constraints=False)
-
-    # perform training
-    average_reward = 0.0
-    for trial in range(kwargs['eval_trials']):
-        key, subkey = jax.random.split(key)
-        total_reward = policy.evaluate(env, seed=np.array(subkey)[0])['mean']
-        if kwargs['verbose']:
-            print(f' [{index}] trial {trial + 1} key={subkey[0]}, '
-                  f'reward={total_reward}', flush=True)
-        average_reward += total_reward / kwargs['eval_trials']
-    if kwargs['verbose']:
-        print(f'[{index}] average reward={average_reward}', flush=True)
-    return average_reward
-
-
-def power_two_int(x):
-    return 2 ** int(x)
-
-
-class JaxParameterTuningDRP(JaxParameterTuning):
-
-    def __init__(self, *args,
-                 hyperparams_dict: Dict[str, Tuple[float, float, Callable]]={
-                     'lr': (-7., 2., power_ten),
-                     'w': (0., 5., power_ten),
-                     'layers': (1., 3., int),
-                     'neurons': (2., 9., power_two_int)
-                 },
-                 **kwargs) -> None:
-        '''Creates a new tuning class for deep reactive policies.
-
-        :param *args: arguments to pass to parent class
-        :param hyperparams_dict: same as parent class, but here must contain
-        learning rate (lr), model weight (w), number of hidden layers (layers)
-        and number of neurons per hidden layer (neurons)
-        :param **kwargs: keyword arguments to pass to parent class
-        '''
-
-        super(JaxParameterTuningDRP, self).__init__(
-            *args, hyperparams_dict=hyperparams_dict, **kwargs)
-
-    def _pickleable_objective_with_kwargs(self):
-        objective_fn = objective_drp
-
-        # duplicate planner and plan keyword arguments must be removed
-        plan_kwargs = self.plan_kwargs.copy()
-        plan_kwargs.pop('topology', None)
-
-        planner_kwargs = self.planner_kwargs.copy()
-        planner_kwargs.pop('rddl', None)
-        planner_kwargs.pop('plan', None)
-        planner_kwargs.pop('optimizer_kwargs', None)
-
-        kwargs = {
-            'rddl': self.env.model,
-            'domain': self.env.domain_text,
-            'instance': self.env.instance_text,
-            'hyperparams_dict': self.hyperparams_dict,
-            'timeout_training': self.timeout_training,
-            'train_epochs': self.train_epochs,
-            'planner_kwargs': planner_kwargs,
-            'plan_kwargs': plan_kwargs,
-            'verbose': self.verbose,
-            'eval_trials': self.eval_trials
-        }
-        return objective_fn, kwargs
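
The removed subclasses above hard-coded their search spaces; for example, JaxParameterTuningSLP tuned std, lr, w, and wa, each mapped through power_ten. Under the 1.0 API the same space is declared explicitly. A minimal sketch, assuming hypothetical tag names (STD_TAG, LR_TAG, W_TAG, WA_TAG) that would need to appear verbatim in the config template; the bounds and the power_ten mapping are copied from the removed 0.5 defaults:

    from pyRDDLGym_jax.core.tuning import Hyperparameter

    def power_ten(x):
        return 10.0 ** x

    # rough equivalent of the 0.5 SLP search space, one entry per tuned quantity
    # (tag names are placeholders chosen for this sketch)
    slp_hyperparams = [
        Hyperparameter('STD_TAG', -5.0, 2.0, power_ten),  # policy weight initializer scale
        Hyperparameter('LR_TAG', -5.0, 2.0, power_ten),   # learning rate
        Hyperparameter('W_TAG', 0.0, 5.0, power_ten),     # model relaxation weight
        Hyperparameter('WA_TAG', 0.0, 5.0, power_ten),    # boolean-action sigmoid weight
    ]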