pyRDDLGym-jax 0.5-py3-none-any.whl → 1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. pyRDDLGym_jax/__init__.py +1 -1
  2. pyRDDLGym_jax/core/compiler.py +463 -592
  3. pyRDDLGym_jax/core/logic.py +784 -544
  4. pyRDDLGym_jax/core/planner.py +336 -472
  5. pyRDDLGym_jax/core/simulator.py +7 -5
  6. pyRDDLGym_jax/core/tuning.py +392 -567
  7. pyRDDLGym_jax/core/visualization.py +1463 -0
  8. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
  9. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +4 -5
  10. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
  11. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
  12. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
  13. pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
  14. pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
  15. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
  16. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +3 -3
  17. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
  18. pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
  19. pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
  20. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
  21. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +3 -3
  22. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
  23. pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
  24. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
  25. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +3 -3
  26. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
  27. pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
  28. pyRDDLGym_jax/examples/configs/default_replan.cfg +3 -3
  29. pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
  30. pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
  31. pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
  32. pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
  33. pyRDDLGym_jax/examples/run_plan.py +4 -1
  34. pyRDDLGym_jax/examples/run_tune.py +40 -27
  35. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.1.dist-info}/METADATA +167 -111
  36. pyRDDLGym_jax-1.1.dist-info/RECORD +45 -0
  37. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.1.dist-info}/WHEEL +1 -1
  38. pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
  39. pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
  40. pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
  41. pyRDDLGym_jax-0.5.dist-info/RECORD +0 -44
  42. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.1.dist-info}/LICENSE +0 -0
  43. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.1.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,15 @@
- from copy import deepcopy
  import csv
  import datetime
- from multiprocessing import get_context
+ import threading
+ import multiprocessing
  import os
  import time
- from typing import Any, Callable, Dict, Optional, Tuple
+ import traceback
+ from typing import Any, Callable, Dict, Iterable, Optional, Tuple
  import warnings
  warnings.filterwarnings("ignore")

+ from sklearn.gaussian_process.kernels import Matern, ConstantKernel
  from bayes_opt import BayesianOptimization
  from bayes_opt.acquisition import AcquisitionFunction, UpperConfidenceBound
  import jax
@@ -18,24 +20,41 @@ from pyRDDLGym.core.env import RDDLEnv

  from pyRDDLGym_jax.core.planner import (
      JaxBackpropPlanner,
-     JaxStraightLinePlan,
-     JaxDeepReactivePolicy,
      JaxOfflineController,
-     JaxOnlineController
+     JaxOnlineController,
+     load_config_from_string
  )

- Kwargs = Dict[str, Any]
+ # try to load the dash board
+ try:
+     from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard
+ except Exception:
+     raise_warning('Failed to load the dashboard visualization tool: '
+                   'please make sure you have installed the required packages.',
+                   'red')
+     traceback.print_exc()
+     JaxPlannerDashboard = None
+
+
+ class Hyperparameter:
+     '''A generic hyper-parameter of the planner that can be tuned.'''
+
+     def __init__(self, tag: str, lower_bound: float, upper_bound: float,
+                  search_to_config_map: Callable) -> None:
+         self.tag = tag
+         self.lower_bound = lower_bound
+         self.upper_bound = upper_bound
+         self.search_to_config_map = search_to_config_map
+
+     def __str__(self) -> str:
+         return (f'{self.search_to_config_map.__name__} '
+                 f': [{self.lower_bound}, {self.upper_bound}] -> {self.tag}')

- # ===============================================================================
- #
- # GENERIC TUNING MODULE
- #
- # Currently contains three implementations:
- # 1. straight line plan
- # 2. re-planning
- # 3. deep reactive policies
- #
- # ===============================================================================
+
+ Kwargs = Dict[str, Any]
+ ParameterValues = Dict[str, Any]
+ Hyperparameters = Iterable[Hyperparameter]
+
  COLUMNS = ['pid', 'worker', 'iteration', 'target', 'best_target', 'acq_params']


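The Hyperparameter class above replaces the old (lower, upper, map) triples of hyperparams_dict. A minimal sketch of how instances might be declared; the tag strings and the power_ten map are illustrative, not part of the released API (each tag must occur verbatim in the config template, otherwise the tuner's constructor filters it out):

def power_ten(x):
    return 10.0 ** x

hyperparams = [
    # tag, lower bound, upper bound, search-space -> config-value map
    Hyperparameter('TUNE_LR', -5.0, 1.0, power_ten),      # learning rate = 10 ** x
    Hyperparameter('TUNE_WEIGHT', 0.0, 5.0, power_ten),   # model relaxation weight
    Hyperparameter('TUNE_LAYERS', 1.0, 3.0, int),         # e.g. hidden layer count
]
print(hyperparams[0])   # power_ten : [-5.0, 1.0] -> TUNE_LR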
@@ -43,14 +62,13 @@ class JaxParameterTuning:
      '''A general-purpose class for tuning a Jax planner.'''

      def __init__(self, env: RDDLEnv,
-                  hyperparams_dict: Dict[str, Tuple[float, float, Callable]],
-                  train_epochs: int,
-                  timeout_training: float,
-                  timeout_tuning: float=np.inf,
+                  config_template: str,
+                  hyperparams: Hyperparameters,
+                  online: bool,
                   eval_trials: int=5,
+                  rollouts_per_trial: int=1,
                   verbose: bool=True,
-                  planner_kwargs: Optional[Kwargs]=None,
-                  plan_kwargs: Optional[Kwargs]=None,
+                  timeout_tuning: float=np.inf,
                   pool_context: str='spawn',
                   num_workers: int=1,
                   poll_frequency: float=0.2,
@@ -62,23 +80,20 @@ class JaxParameterTuning:
          on the given RDDL domain and instance.

          :param env: the RDDLEnv describing the MDP to optimize
-         :param hyperparams_dict: dictionary mapping name of each hyperparameter
-         to a triple, where the first two elements are lower/upper bounds on the
-         parameter value, and the last is a callable mapping the parameter to its
-         RDDL equivalent
-         :param train_epochs: the maximum number of iterations of SGD per
-         step or trial
-         :param timeout_training: the maximum amount of time to spend training per
-         trial/decision step (in seconds)
+         :param config_template: base configuration file content to tune: regex
+         matches are specified directly in the config and map to keys in the
+         hyperparams_dict field
+         :param hyperparams: list of hyper-parameters to regex replace in the
+         config template during tuning
+         :param online: whether the planner is optimized online or offline
          :param timeout_tuning: the maximum amount of time to spend tuning
          hyperparameters in general (in seconds)
          :param eval_trials: how many trials to perform independent training
          in order to estimate the return for each set of hyper-parameters
+         :param rollouts_per_trial: how many rollouts to perform during evaluation
+         at the end of each training trial (only applies when online=False)
          :param verbose: whether to print intermediate results of tuning
-         :param planner_kwargs: additional arguments to feed to the planner
-         :param plan_kwargs: additional arguments to feed to the plan/policy
-         :param pool_context: context for multiprocessing pool (defaults to
-         "spawn")
+         :param pool_context: context for multiprocessing pool (default "spawn")
          :param num_workers: how many points to evaluate in parallel
          :param poll_frequency: how often (in seconds) to poll for completed
          jobs, necessary if num_workers > 1
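The config_template/hyperparams pair works by plain string substitution: every Hyperparameter whose tag occurs in the template is kept, the Bayesian optimizer proposes a value inside [lower_bound, upper_bound], the search_to_config_map converts it, and config_from_template replaces the tag before the result is parsed by load_config_from_string. A sketch of the mechanism; the section and key names in the template string are illustrative and not taken from the shipped tuning_*.cfg examples:

def power_ten(x):
    return 10.0 ** x

template = """
[Optimizer]
method='JaxStraightLinePlan'
optimizer_kwargs={'learning_rate': TUNE_LR}

[Training]
train_seconds=30
"""

# a point proposed by the optimizer in the search space [-5, 1]
search_point = {'TUNE_LR': -3.0}
hyperparams_dict = {'TUNE_LR': Hyperparameter('TUNE_LR', -5.0, 1.0, power_ten)}
config_values = JaxParameterTuning.search_to_config_params(hyperparams_dict, search_point)
print(JaxParameterTuning.config_from_template(template, config_values))
# ... optimizer_kwargs={'learning_rate': 0.001} ...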
@@ -88,21 +103,21 @@ class JaxParameterTuning:
          during initialization
          :param gp_params: additional parameters to feed to Bayesian optimizer
          after initialization optimization
-         '''
-
+         '''
+         # objective parameters
          self.env = env
+         self.config_template = config_template
+         hyperparams_dict = {hyper_param.tag: hyper_param
+                             for hyper_param in hyperparams
+                             if hyper_param.tag in config_template}
          self.hyperparams_dict = hyperparams_dict
-         self.train_epochs = train_epochs
-         self.timeout_training = timeout_training
-         self.timeout_tuning = timeout_tuning
+         self.online = online
          self.eval_trials = eval_trials
+         self.rollouts_per_trial = rollouts_per_trial
          self.verbose = verbose
-         if planner_kwargs is None:
-             planner_kwargs = {}
-         self.planner_kwargs = planner_kwargs
-         if plan_kwargs is None:
-             plan_kwargs = {}
-         self.plan_kwargs = plan_kwargs
+
+         # Bayesian parameters
+         self.timeout_tuning = timeout_tuning
          self.pool_context = pool_context
          self.num_workers = num_workers
          self.poll_frequency = poll_frequency
@@ -111,20 +126,33 @@ class JaxParameterTuning:
              gp_init_kwargs = {}
          self.gp_init_kwargs = gp_init_kwargs
          if gp_params is None:
-             gp_params = {'n_restarts_optimizer': 10}
+             gp_params = {'n_restarts_optimizer': 25,
+                          'kernel': self.make_default_kernel()}
          self.gp_params = gp_params
-
-         # create acquisition function
          if acquisition is None:
              num_samples = self.gp_iters * self.num_workers
-             acquisition = JaxParameterTuning._annealing_acquisition(num_samples)
+             acquisition = JaxParameterTuning.annealing_acquisition(num_samples)
          self.acquisition = acquisition

+     @staticmethod
+     def make_default_kernel():
+         weight1 = ConstantKernel(1.0, (0.01, 100.0))
+         weight2 = ConstantKernel(1.0, (0.01, 100.0))
+         weight3 = ConstantKernel(1.0, (0.01, 100.0))
+         kernel1 = Matern(length_scale=0.5, length_scale_bounds=(0.1, 0.5), nu=2.5)
+         kernel2 = Matern(length_scale=1.0, length_scale_bounds=(0.5, 1.0), nu=2.5)
+         kernel3 = Matern(length_scale=5.0, length_scale_bounds=(1.0, 5.0), nu=2.5)
+         return weight1 * kernel1 + weight2 * kernel2 + weight3 * kernel3
+
      def summarize_hyperparameters(self) -> None:
+         hyper_params_table = []
+         for (_, param) in self.hyperparams_dict.items():
+             hyper_params_table.append(f' {str(param)}')
+         hyper_params_table = '\n'.join(hyper_params_table)
          print(f'hyperparameter optimizer parameters:\n'
-               f' tuned_hyper_parameters ={self.hyperparams_dict}\n'
+               f' tuned_hyper_parameters =\n{hyper_params_table}\n'
                f' initialization_args ={self.gp_init_kwargs}\n'
-               f' additional_args ={self.gp_params}\n'
+               f' gp_params ={self.gp_params}\n'
                f' tuning_iterations ={self.gp_iters}\n'
                f' tuning_timeout ={self.timeout_tuning}\n'
                f' tuning_batch_size ={self.num_workers}\n'
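The default surrogate model is now a weighted sum of three Matern(nu=2.5) kernels with short, medium, and long length scales, so the Gaussian process can capture both fine local structure and broad trends of the meta-objective. A different kernel can still be supplied through the gp_params constructor argument; a sketch with an illustrative RBF choice rather than anything shipped with the package:

from sklearn.gaussian_process.kernels import ConstantKernel, RBF

custom_gp_params = {
    'n_restarts_optimizer': 10,
    'kernel': ConstantKernel(1.0, (0.01, 100.0)) * RBF(length_scale=1.0),
}
# tuning = JaxParameterTuning(..., gp_params=custom_gp_params)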
@@ -132,43 +160,233 @@ class JaxParameterTuning:
132
160
  f' mp_pool_poll_frequency ={self.poll_frequency}\n'
133
161
  f'meta-objective parameters:\n'
134
162
  f' planning_trials_per_iter ={self.eval_trials}\n'
135
- f' planning_iters_per_trial ={self.train_epochs}\n'
136
- f' planning_timeout_per_trial={self.timeout_training}\n'
163
+ f' rollouts_per_trial ={self.rollouts_per_trial}\n'
137
164
  f' acquisition_fn ={self.acquisition}')
138
165
 
139
166
  @staticmethod
140
- def _annealing_acquisition(n_samples, n_delay_samples=0, kappa1=10.0, kappa2=1.0):
167
+ def annealing_acquisition(n_samples: int, n_delay_samples: int=0,
168
+ kappa1: float=10.0, kappa2: float=1.0) -> UpperConfidenceBound:
141
169
  acq_fn = UpperConfidenceBound(
142
170
  kappa=kappa1,
143
171
  exploration_decay=(kappa2 / kappa1) ** (1.0 / (n_samples - n_delay_samples)),
144
- exploration_decay_delay=n_delay_samples)
172
+ exploration_decay_delay=n_delay_samples
173
+ )
145
174
  return acq_fn
146
175
 
147
- def _pickleable_objective_with_kwargs(self):
148
- raise NotImplementedError
176
+ @staticmethod
177
+ def search_to_config_params(hyper_params: Hyperparameters,
178
+ params: ParameterValues) -> ParameterValues:
179
+ config_params = {
180
+ tag: param.search_to_config_map(params[tag])
181
+ for (tag, param) in hyper_params.items()
182
+ }
183
+ return config_params
184
+
185
+ @staticmethod
186
+ def config_from_template(config_template: str,
187
+ config_params: ParameterValues) -> str:
188
+ config_string = config_template
189
+ for (tag, param_value) in config_params.items():
190
+ config_string = config_string.replace(tag, str(param_value))
191
+ return config_string
192
+
193
+ @property
194
+ def best_config(self) -> str:
195
+ return self.config_from_template(self.config_template, self.best_params)
196
+
197
+ @staticmethod
198
+ def queue_listener(queue, dashboard):
199
+ while True:
200
+ args = queue.get()
201
+ if args is None:
202
+ break
203
+ elif len(args) == 2:
204
+ dashboard.update_experiment(*args)
205
+ else:
206
+ dashboard.register_experiment(*args)
149
207
 
150
208
  @staticmethod
151
- def _wrapped_evaluate(index, params, key, func, kwargs):
152
- target = func(params=params, kwargs=kwargs, key=key, index=index)
209
+ def offline_trials(env, planner, train_args, key, iteration, index, num_trials,
210
+ rollouts_per_trial, verbose, viz, queue):
211
+ average_reward = 0.0
212
+ for trial in range(num_trials):
213
+ key, subkey = jax.random.split(key)
214
+
215
+ # for the dashboard
216
+ experiment_id = f'iter={iteration}, worker={index}, trial={trial}'
217
+ if queue is not None and JaxPlannerDashboard is not None:
218
+ queue.put((
219
+ experiment_id,
220
+ JaxPlannerDashboard.get_planner_info(planner),
221
+ subkey[0],
222
+ viz
223
+ ))
224
+
225
+ # train the policy
226
+ callback = None
227
+ for callback in planner.optimize_generator(key=subkey, **train_args):
228
+ if queue is not None and queue.empty():
229
+ queue.put((experiment_id, callback))
230
+ best_params = None if callback is None else callback['best_params']
231
+
232
+ # evaluate the policy in the real environment
233
+ policy = JaxOfflineController(
234
+ planner=planner, key=subkey, tqdm_position=index,
235
+ params=best_params, train_on_reset=False)
236
+ total_reward = policy.evaluate(env, episodes=rollouts_per_trial,
237
+ seed=np.array(subkey)[0])['mean']
238
+
239
+ # update average reward
240
+ if verbose:
241
+ iters = None if callback is None else callback['iteration']
242
+ print(f' [{index}] trial {trial + 1}, key={subkey[0]}, '
243
+ f'reward={total_reward:.6f}, iters={iters}', flush=True)
244
+ average_reward += total_reward / num_trials
245
+
246
+ if verbose:
247
+ print(f'[{index}] average reward={average_reward:.6f}', flush=True)
248
+ return average_reward
249
+
250
+ @staticmethod
251
+ def online_trials(env, planner, train_args, key, iteration, index, num_trials,
252
+ verbose, viz, queue):
253
+ average_reward = 0.0
254
+ for trial in range(num_trials):
255
+ key, subkey = jax.random.split(key)
256
+
257
+ # for the dashboard
258
+ experiment_id = f'iter={iteration}, worker={index}, trial={trial}'
259
+ if queue is not None and JaxPlannerDashboard is not None:
260
+ queue.put((
261
+ experiment_id,
262
+ JaxPlannerDashboard.get_planner_info(planner),
263
+ subkey[0],
264
+ viz
265
+ ))
266
+
267
+ # initialize the online policy
268
+ policy = JaxOnlineController(
269
+ planner=planner, key=subkey, tqdm_position=index, **train_args)
270
+
271
+ # evaluate the policy in the real environment
272
+ total_reward = 0.0
273
+ callback = None
274
+ state, _ = env.reset(seed=np.array(subkey)[0])
275
+ elapsed_time = 0.0
276
+ for step in range(env.horizon):
277
+ action = policy.sample_action(state)
278
+ next_state, reward, terminated, truncated, _ = env.step(action)
279
+ total_reward += reward
280
+ done = terminated or truncated
281
+ state = next_state
282
+ callback = policy.callback
283
+ elapsed_time += callback['elapsed_time']
284
+ callback['iteration'] = step
285
+ callback['progress'] = int(100 * (step + 1.) / env.horizon)
286
+ callback['elapsed_time'] = elapsed_time
287
+ if queue is not None and queue.empty():
288
+ queue.put((experiment_id, callback))
289
+ if done:
290
+ break
291
+
292
+ # update average reward
293
+ if verbose:
294
+ iters = None if callback is None else callback['iteration']
295
+ print(f' [{index}] trial {trial + 1}, key={subkey[0]}, '
296
+ f'reward={total_reward:.6f}, iters={iters}', flush=True)
297
+ average_reward += total_reward / num_trials
298
+
299
+ if verbose:
300
+ print(f'[{index}] average reward={average_reward:.6f}', flush=True)
301
+ return average_reward
302
+
303
+ @staticmethod
304
+ def objective_function(params: ParameterValues,
305
+ key: jax.random.PRNGKey,
306
+ index: int,
307
+ iteration: int,
308
+ kwargs: Kwargs,
309
+ queue: object) -> Tuple[ParameterValues, float, int, int]:
310
+ '''A pickleable objective function to evaluate a single hyper-parameter
311
+ configuration.'''
312
+
313
+ hyperparams_dict = kwargs['hyperparams_dict']
314
+ config_template = kwargs['config_template']
315
+ online = kwargs['online']
316
+ domain = kwargs['domain']
317
+ instance = kwargs['instance']
318
+ num_trials = kwargs['eval_trials']
319
+ rollouts_per_trial = kwargs['rollouts_per_trial']
320
+ viz = kwargs['viz']
321
+ verbose = kwargs['verbose']
322
+
323
+ # config string substitution and parsing
324
+ config_params = JaxParameterTuning.search_to_config_params(hyperparams_dict, params)
325
+ if verbose:
326
+ config_param_str = ', '.join(
327
+ f'{k}={v}' for (k, v) in config_params.items())
328
+ print(f'[{index}] key={key[0]}, {config_param_str}', flush=True)
329
+ config_string = JaxParameterTuning.config_from_template(config_template, config_params)
330
+ planner_args, _, train_args = load_config_from_string(config_string)
331
+
332
+ # remove keywords that should not be in the tuner
333
+ train_args.pop('dashboard', None)
334
+
335
+ # initialize env for evaluation (need fresh copy to avoid concurrency)
336
+ env = RDDLEnv(domain, instance, vectorized=True, enforce_action_constraints=False)
337
+
338
+ # run planning algorithm
339
+ planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
340
+ if online:
341
+ average_reward = JaxParameterTuning.online_trials(
342
+ env, planner, train_args, key, iteration, index, num_trials,
343
+ verbose, viz, queue
344
+ )
345
+ else:
346
+ average_reward = JaxParameterTuning.offline_trials(
347
+ env, planner, train_args, key, iteration, index,
348
+ num_trials, rollouts_per_trial, verbose, viz, queue
349
+ )
350
+
153
351
  pid = os.getpid()
154
- return index, pid, params, target
155
-
156
- def tune(self, key: jax.random.PRNGKey,
157
- filename: str,
158
- save_plot: bool=False) -> Dict[str, Any]:
352
+ return params, average_reward, index, pid
353
+
354
+ def tune_optimizer(self, optimizer: BayesianOptimization) -> None:
355
+ '''Tunes the Bayesian optimization algorithm hyper-parameters.'''
356
+ print('\n' + f'The current kernel is {repr(optimizer._gp.kernel_)}.')
357
+
358
+ def tune(self, key: int, log_file: str, show_dashboard: bool=False) -> ParameterValues:
159
359
  '''Tunes the hyper-parameters for Jax planner, returns the best found.'''
360
+
160
361
  self.summarize_hyperparameters()
161
362
 
162
- start_time = time.time()
363
+ # clear and prepare output file
364
+ with open(log_file, 'w', newline='') as file:
365
+ writer = csv.writer(file)
366
+ writer.writerow(COLUMNS + list(self.hyperparams_dict.keys()))
163
367
 
164
- # objective function
165
- objective = self._pickleable_objective_with_kwargs()
166
- evaluate = JaxParameterTuning._wrapped_evaluate
368
+ # create a dash-board for visualizing experiment runs
369
+ if show_dashboard and JaxPlannerDashboard is not None:
370
+ dashboard = JaxPlannerDashboard()
371
+ dashboard.launch()
167
372
 
373
+ # objective function auxiliary data
374
+ obj_kwargs = {
375
+ 'hyperparams_dict': self.hyperparams_dict,
376
+ 'config_template': self.config_template,
377
+ 'online': self.online,
378
+ 'domain': self.env.domain_text,
379
+ 'instance': self.env.instance_text,
380
+ 'eval_trials': self.eval_trials,
381
+ 'rollouts_per_trial': self.rollouts_per_trial,
382
+ 'viz': self.env._visualizer,
383
+ 'verbose': self.verbose
384
+ }
385
+
168
386
  # create optimizer
169
387
  hyperparams_bounds = {
170
- name: hparam[:2]
171
- for (name, hparam) in self.hyperparams_dict.items()
388
+ tag: (param.lower_bound, param.upper_bound)
389
+ for (tag, param) in self.hyperparams_dict.items()
172
390
  }
173
391
  optimizer = BayesianOptimization(
174
392
  f=None,
@@ -182,91 +400,116 @@ class JaxParameterTuning:
182
400
 
183
401
  # suggest initial parameters to evaluate
184
402
  num_workers = self.num_workers
185
- suggested, acq_params = [], []
403
+ suggested_params, acq_params = [], []
186
404
  for _ in range(num_workers):
187
405
  probe = optimizer.suggest()
188
- suggested.append(probe)
406
+ suggested_params.append(probe)
189
407
  acq_params.append(vars(optimizer.acquisition_function))
190
408
 
191
- # clear and prepare output file
192
- filename = self._filename(filename, 'csv')
193
- with open(filename, 'w', newline='') as file:
194
- writer = csv.writer(file)
195
- writer.writerow(COLUMNS + list(hyperparams_bounds.keys()))
196
-
197
- # start multiprocess evaluation
198
- worker_ids = list(range(num_workers))
199
- best_params, best_target = None, -np.inf
200
-
201
- for it in range(self.gp_iters):
409
+ with multiprocessing.Manager() as manager:
202
410
 
203
- # check if there is enough time left for another iteration
204
- elapsed = time.time() - start_time
205
- if elapsed >= self.timeout_tuning:
206
- print(f'global time limit reached at iteration {it}, aborting')
207
- break
411
+ # queue and parallel thread for handing render events
412
+ if show_dashboard:
413
+ queue = manager.Queue()
414
+ dashboard_thread = threading.Thread(
415
+ target=JaxParameterTuning.queue_listener,
416
+ args=(queue, dashboard)
417
+ )
418
+ dashboard_thread.start()
419
+ else:
420
+ queue = None
208
421
 
209
- # continue with next iteration
210
- print('\n' + '*' * 25 +
211
- f'\n[{datetime.timedelta(seconds=elapsed)}] ' +
212
- f'starting iteration {it + 1}' +
213
- '\n' + '*' * 25)
214
- key, *subkeys = jax.random.split(key, num=num_workers + 1)
215
- rows = [None] * num_workers
422
+ # start multiprocess evaluation
423
+ worker_ids = list(range(num_workers))
424
+ best_params, best_target = None, -np.inf
425
+ key = jax.random.PRNGKey(key)
426
+ start_time = time.time()
216
427
 
217
- # create worker pool: note each iteration must wait for all workers
218
- # to finish before moving to the next
219
- with get_context(self.pool_context).Pool(processes=num_workers) as pool:
428
+ for it in range(self.gp_iters):
220
429
 
221
- # assign jobs to worker pool
222
- # - each trains on suggested parameters from the last iteration
223
- # - this way, since each job finishes asynchronously, these
224
- # parameters usually differ across jobs
225
- results = [
226
- pool.apply_async(evaluate, worker_args + objective)
227
- for worker_args in zip(worker_ids, suggested, subkeys)
228
- ]
229
-
230
- # wait for all workers to complete
231
- while results:
232
- time.sleep(self.poll_frequency)
233
-
234
- # determine which jobs have completed
235
- jobs_done = []
236
- for (i, candidate) in enumerate(results):
237
- if candidate.ready():
238
- jobs_done.append(i)
430
+ # check if there is enough time left for another iteration
431
+ elapsed = time.time() - start_time
432
+ if elapsed >= self.timeout_tuning:
433
+ print(f'global time limit reached at iteration {it}, aborting')
434
+ break
435
+
436
+ # continue with next iteration
437
+ print('\n' + '*' * 80 +
438
+ f'\n[{datetime.timedelta(seconds=elapsed)}] ' +
439
+ f'starting iteration {it + 1}' +
440
+ '\n' + '*' * 80)
441
+ key, *subkeys = jax.random.split(key, num=num_workers + 1)
442
+ rows = [None] * num_workers
443
+ old_best_target = best_target
444
+
445
+ # create worker pool: note each iteration must wait for all workers
446
+ # to finish before moving to the next
447
+ with multiprocessing.get_context(
448
+ self.pool_context).Pool(processes=num_workers) as pool:
239
449
 
240
- # get result from completed jobs
241
- for i in jobs_done[::-1]:
242
-
243
- # extract and register the new evaluation
244
- index, pid, params, target = results.pop(i).get()
245
- optimizer.register(params, target)
246
-
247
- # update acquisition function and suggest a new point
248
- suggested[index] = optimizer.suggest()
249
- old_acq_params = acq_params[index]
250
- acq_params[index] = vars(optimizer.acquisition_function)
251
-
252
- # transform suggestion back to natural space
253
- rddl_params = {
254
- name: pf(params[name])
255
- for (name, (*_, pf)) in self.hyperparams_dict.items()
256
- }
257
-
258
- # update the best suggestion so far
259
- if target > best_target:
260
- best_params, best_target = rddl_params, target
450
+ # assign jobs to worker pool
451
+ results = [
452
+ pool.apply_async(JaxParameterTuning.objective_function,
453
+ obj_args + (it, obj_kwargs, queue))
454
+ for obj_args in zip(suggested_params, subkeys, worker_ids)
455
+ ]
456
+
457
+ # wait for all workers to complete
458
+ while results:
459
+ time.sleep(self.poll_frequency)
261
460
 
262
- # write progress to file in real time
263
- info_i = [pid, index, it, target, best_target, old_acq_params]
264
- rows[index] = info_i + list(rddl_params.values())
461
+ # determine which jobs have completed
462
+ jobs_done = []
463
+ for (i, candidate) in enumerate(results):
464
+ if candidate.ready():
465
+ jobs_done.append(i)
265
466
 
266
- # write results of all processes in current iteration to file
267
- with open(filename, 'a', newline='') as file:
268
- writer = csv.writer(file)
269
- writer.writerows(rows)
467
+ # get result from completed jobs
468
+ for i in jobs_done[::-1]:
469
+
470
+ # extract and register the new evaluation
471
+ params, target, index, pid = results.pop(i).get()
472
+ optimizer.register(params, target)
473
+ optimizer._gp.fit(
474
+ optimizer.space.params, optimizer.space.target)
475
+
476
+ # update acquisition function and suggest a new point
477
+ suggested_params[index] = optimizer.suggest()
478
+ old_acq_params = acq_params[index]
479
+ acq_params[index] = vars(optimizer.acquisition_function)
480
+
481
+ # transform suggestion back to natural space
482
+ config_params = JaxParameterTuning.search_to_config_params(
483
+ self.hyperparams_dict, params)
484
+
485
+ # update the best suggestion so far
486
+ if target > best_target:
487
+ best_params, best_target = config_params, target
488
+
489
+ rows[index] = [pid, index, it, target,
490
+ best_target, old_acq_params] + \
491
+ list(config_params.values())
492
+
493
+ # print best parameter if found
494
+ if best_target > old_best_target:
495
+ print(f'* found new best average reward {best_target:.6f}')
496
+
497
+ # tune the optimizer here
498
+ self.tune_optimizer(optimizer)
499
+
500
+ # write results of all processes in current iteration to file
501
+ with open(log_file, 'a', newline='') as file:
502
+ writer = csv.writer(file)
503
+ writer.writerows(rows)
504
+
505
+ # update the dashboard tuning
506
+ if show_dashboard:
507
+ dashboard.update_tuning(optimizer, hyperparams_bounds)
508
+
509
+ # stop the queue listener thread
510
+ if show_dashboard:
511
+ queue.put(None)
512
+ dashboard_thread.join()
270
513
 
271
514
  # print summary of results
272
515
  elapsed = time.time() - start_time
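The annealing_acquisition helper above fixes the exploration schedule of the upper confidence bound: kappa starts at kappa1 and decays geometrically so that it reaches kappa2 after all gp_iters * num_workers evaluations. A short sketch of the implied schedule (the iteration counts are illustrative):

import numpy as np

kappa1, kappa2 = 10.0, 1.0
n_samples = 20 * 4                                # e.g. gp_iters=20, num_workers=4
decay = (kappa2 / kappa1) ** (1.0 / n_samples)    # same formula as annealing_acquisition
kappas = kappa1 * decay ** np.arange(n_samples + 1)
print(round(kappas[0], 3), round(kappas[-1], 3))  # 10.0 1.0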
@@ -274,427 +517,9 @@ class JaxParameterTuning:
274
517
  f' time_elapsed ={datetime.timedelta(seconds=elapsed)}\n'
275
518
  f' iterations ={it + 1}\n'
276
519
  f' best_hyper_parameters={best_params}\n'
277
- f' best_meta_objective ={best_target}\n')
520
+ f' best_meta_objective ={best_target}\n')
278
521
 
279
- if save_plot:
280
- self._save_plot(filename)
522
+ self.best_params = best_params
523
+ self.optimizer = optimizer
524
+ self.log_file = log_file
281
525
  return best_params
282
-
283
- def _filename(self, name, ext):
284
- domain_name = ''.join(c for c in self.env.model.domain_name
285
- if c.isalnum() or c == '_')
286
- instance_name = ''.join(c for c in self.env.model.instance_name
287
- if c.isalnum() or c == '_')
288
- filename = f'{name}_{domain_name}_{instance_name}.{ext}'
289
- return filename
290
-
291
- def _save_plot(self, filename):
292
- try:
293
- import matplotlib.pyplot as plt
294
- from sklearn.manifold import MDS
295
- except Exception as e:
296
- raise_warning(f'failed to import packages matplotlib or sklearn, '
297
- f'aborting plot of search space\n{e}', 'red')
298
- else:
299
- with open(filename, 'r') as file:
300
- data_iter = csv.reader(file, delimiter=',')
301
- data = [row for row in data_iter]
302
- data = np.asarray(data, dtype=object)
303
- hparam = data[1:, len(COLUMNS):].astype(np.float64)
304
- target = data[1:, 3].astype(np.float64)
305
- target = (target - np.min(target)) / (np.max(target) - np.min(target))
306
- embedding = MDS(n_components=2, normalized_stress='auto')
307
- hparam_low = embedding.fit_transform(hparam)
308
- sc = plt.scatter(hparam_low[:, 0], hparam_low[:, 1], c=target, s=5,
309
- cmap='seismic', edgecolor='gray', linewidth=0)
310
- ax = plt.gca()
311
- for i in range(len(target)):
312
- ax.annotate(str(i), (hparam_low[i, 0], hparam_low[i, 1]), fontsize=3)
313
- plt.colorbar(sc)
314
- plt.savefig(self._filename('gp_points', 'pdf'))
315
- plt.clf()
316
- plt.close()
317
-
318
-
319
- # ===============================================================================
320
- #
321
- # STRAIGHT LINE PLANNING
322
- #
323
- # ===============================================================================
324
- def objective_slp(params, kwargs, key, index):
325
-
326
- # transform hyper-parameters to natural space
327
- param_values = [
328
- pmap(params[name])
329
- for (name, (*_, pmap)) in kwargs['hyperparams_dict'].items()
330
- ]
331
-
332
- # unpack hyper-parameters
333
- if kwargs['wrapped_bool_actions']:
334
- std, lr, w, wa = param_values
335
- else:
336
- std, lr, w = param_values
337
- wa = None
338
- key, subkey = jax.random.split(key)
339
- if kwargs['verbose']:
340
- print(f'[{index}] key={subkey[0]}, '
341
- f'std={std}, lr={lr}, w={w}, wa={wa}...', flush=True)
342
-
343
- # initialize planning algorithm
344
- planner = JaxBackpropPlanner(
345
- rddl=deepcopy(kwargs['rddl']),
346
- plan=JaxStraightLinePlan(
347
- initializer=jax.nn.initializers.normal(std),
348
- **kwargs['plan_kwargs']),
349
- optimizer_kwargs={'learning_rate': lr},
350
- **kwargs['planner_kwargs'])
351
- policy_hparams = {name: wa for name in kwargs['wrapped_bool_actions']}
352
- model_params = {name: w for name in planner.compiled.model_params}
353
-
354
- # initialize policy
355
- policy = JaxOfflineController(
356
- planner=planner,
357
- key=subkey,
358
- eval_hyperparams=policy_hparams,
359
- train_on_reset=True,
360
- epochs=kwargs['train_epochs'],
361
- train_seconds=kwargs['timeout_training'],
362
- model_params=model_params,
363
- policy_hyperparams=policy_hparams,
364
- print_summary=False,
365
- print_progress=False,
366
- tqdm_position=index)
367
-
368
- # initialize env for evaluation (need fresh copy to avoid concurrency)
369
- env = RDDLEnv(domain=kwargs['domain'],
370
- instance=kwargs['instance'],
371
- vectorized=True,
372
- enforce_action_constraints=False)
373
-
374
- # perform training
375
- average_reward = 0.0
376
- for trial in range(kwargs['eval_trials']):
377
- key, subkey = jax.random.split(key)
378
- total_reward = policy.evaluate(env, seed=np.array(subkey)[0])['mean']
379
- if kwargs['verbose']:
380
- print(f' [{index}] trial {trial + 1} key={subkey[0]}, '
381
- f'reward={total_reward}', flush=True)
382
- average_reward += total_reward / kwargs['eval_trials']
383
- if kwargs['verbose']:
384
- print(f'[{index}] average reward={average_reward}', flush=True)
385
- return average_reward
386
-
387
-
388
- def power_ten(x):
389
- return 10.0 ** x
390
-
391
-
392
- class JaxParameterTuningSLP(JaxParameterTuning):
393
-
394
- def __init__(self, *args,
395
- hyperparams_dict: Dict[str, Tuple[float, float, Callable]]={
396
- 'std': (-5., 2., power_ten),
397
- 'lr': (-5., 2., power_ten),
398
- 'w': (0., 5., power_ten),
399
- 'wa': (0., 5., power_ten)
400
- },
401
- **kwargs) -> None:
402
- '''Creates a new tuning class for straight line planners.
403
-
404
- :param *args: arguments to pass to parent class
405
- :param hyperparams_dict: same as parent class, but here must contain
406
- weight initialization (std), learning rate (lr), model weight (w), and
407
- action weight (wa) if wrap_sigmoid and boolean action fluents exist
408
- :param **kwargs: keyword arguments to pass to parent class
409
- '''
410
-
411
- super(JaxParameterTuningSLP, self).__init__(
412
- *args, hyperparams_dict=hyperparams_dict, **kwargs)
413
-
414
- # action parameters required if wrap_sigmoid and boolean action exists
415
- self.wrapped_bool_actions = []
416
- if self.plan_kwargs.get('wrap_sigmoid', True):
417
- for var in self.env.model.action_fluents:
418
- if self.env.model.variable_ranges[var] == 'bool':
419
- self.wrapped_bool_actions.append(var)
420
- if not self.wrapped_bool_actions:
421
- self.hyperparams_dict.pop('wa', None)
422
-
423
- def _pickleable_objective_with_kwargs(self):
424
- objective_fn = objective_slp
425
-
426
- # duplicate planner and plan keyword arguments must be removed
427
- plan_kwargs = self.plan_kwargs.copy()
428
- plan_kwargs.pop('initializer', None)
429
-
430
- planner_kwargs = self.planner_kwargs.copy()
431
- planner_kwargs.pop('rddl', None)
432
- planner_kwargs.pop('plan', None)
433
- planner_kwargs.pop('optimizer_kwargs', None)
434
-
435
- kwargs = {
436
- 'rddl': self.env.model,
437
- 'domain': self.env.domain_text,
438
- 'instance': self.env.instance_text,
439
- 'hyperparams_dict': self.hyperparams_dict,
440
- 'timeout_training': self.timeout_training,
441
- 'train_epochs': self.train_epochs,
442
- 'planner_kwargs': planner_kwargs,
443
- 'plan_kwargs': plan_kwargs,
444
- 'verbose': self.verbose,
445
- 'wrapped_bool_actions': self.wrapped_bool_actions,
446
- 'eval_trials': self.eval_trials
447
- }
448
- return objective_fn, kwargs
449
-
450
-
451
- # ===============================================================================
452
- #
453
- # REPLANNING
454
- #
455
- # ===============================================================================
456
- def objective_replan(params, kwargs, key, index):
457
-
458
- # transform hyper-parameters to natural space
459
- param_values = [
460
- pmap(params[name])
461
- for (name, (*_, pmap)) in kwargs['hyperparams_dict'].items()
462
- ]
463
-
464
- # unpack hyper-parameters
465
- if kwargs['wrapped_bool_actions']:
466
- std, lr, w, wa, T = param_values
467
- else:
468
- std, lr, w, T = param_values
469
- wa = None
470
- key, subkey = jax.random.split(key)
471
- if kwargs['verbose']:
472
- print(f'[{index}] key={subkey[0]}, '
473
- f'std={std}, lr={lr}, w={w}, wa={wa}, T={T}...', flush=True)
474
-
475
- # initialize planning algorithm
476
- planner = JaxBackpropPlanner(
477
- rddl=deepcopy(kwargs['rddl']),
478
- plan=JaxStraightLinePlan(
479
- initializer=jax.nn.initializers.normal(std),
480
- **kwargs['plan_kwargs']),
481
- rollout_horizon=T,
482
- optimizer_kwargs={'learning_rate': lr},
483
- **kwargs['planner_kwargs'])
484
- policy_hparams = {name: wa for name in kwargs['wrapped_bool_actions']}
485
- model_params = {name: w for name in planner.compiled.model_params}
486
-
487
- # initialize controller
488
- policy = JaxOnlineController(
489
- planner=planner,
490
- key=subkey,
491
- eval_hyperparams=policy_hparams,
492
- warm_start=kwargs['use_guess_last_epoch'],
493
- epochs=kwargs['train_epochs'],
494
- train_seconds=kwargs['timeout_training'],
495
- model_params=model_params,
496
- policy_hyperparams=policy_hparams,
497
- print_summary=False,
498
- print_progress=False,
499
- tqdm_position=index)
500
-
501
- # initialize env for evaluation (need fresh copy to avoid concurrency)
502
- env = RDDLEnv(domain=kwargs['domain'],
503
- instance=kwargs['instance'],
504
- vectorized=True,
505
- enforce_action_constraints=False)
506
-
507
- # perform training
508
- average_reward = 0.0
509
- for trial in range(kwargs['eval_trials']):
510
- key, subkey = jax.random.split(key)
511
- total_reward = policy.evaluate(env, seed=np.array(subkey)[0])['mean']
512
- if kwargs['verbose']:
513
- print(f' [{index}] trial {trial + 1} key={subkey[0]}, '
514
- f'reward={total_reward}', flush=True)
515
- average_reward += total_reward / kwargs['eval_trials']
516
- if kwargs['verbose']:
517
- print(f'[{index}] average reward={average_reward}', flush=True)
518
- return average_reward
519
-
520
-
521
- class JaxParameterTuningSLPReplan(JaxParameterTuningSLP):
522
-
523
- def __init__(self,
524
- *args,
525
- hyperparams_dict: Dict[str, Tuple[float, float, Callable]]={
526
- 'std': (-5., 2., power_ten),
527
- 'lr': (-5., 2., power_ten),
528
- 'w': (0., 5., power_ten),
529
- 'wa': (0., 5., power_ten),
530
- 'T': (1, None, int)
531
- },
532
- use_guess_last_epoch: bool=True,
533
- **kwargs) -> None:
534
- '''Creates a new tuning class for straight line planners.
535
-
536
- :param *args: arguments to pass to parent class
537
- :param hyperparams_dict: same as parent class, but here must contain
538
- weight initialization (std), learning rate (lr), model weight (w),
539
- action weight (wa) if wrap_sigmoid and boolean action fluents exist, and
540
- lookahead horizon (T)
541
- :param use_guess_last_epoch: use the trained parameters from previous
542
- decision to warm-start next decision
543
- :param **kwargs: keyword arguments to pass to parent class
544
- '''
545
-
546
- super(JaxParameterTuningSLPReplan, self).__init__(
547
- *args, hyperparams_dict=hyperparams_dict, **kwargs)
548
-
549
- self.use_guess_last_epoch = use_guess_last_epoch
550
-
551
- # set upper range of lookahead horizon to environment horizon
552
- if self.hyperparams_dict['T'][1] is None:
553
- self.hyperparams_dict['T'] = (1, self.env.horizon, int)
554
-
555
- def _pickleable_objective_with_kwargs(self):
556
- objective_fn = objective_replan
557
-
558
- # duplicate planner and plan keyword arguments must be removed
559
- plan_kwargs = self.plan_kwargs.copy()
560
- plan_kwargs.pop('initializer', None)
561
-
562
- planner_kwargs = self.planner_kwargs.copy()
563
- planner_kwargs.pop('rddl', None)
564
- planner_kwargs.pop('plan', None)
565
- planner_kwargs.pop('rollout_horizon', None)
566
- planner_kwargs.pop('optimizer_kwargs', None)
567
-
568
- kwargs = {
569
- 'rddl': self.env.model,
570
- 'domain': self.env.domain_text,
571
- 'instance': self.env.instance_text,
572
- 'hyperparams_dict': self.hyperparams_dict,
573
- 'timeout_training': self.timeout_training,
574
- 'train_epochs': self.train_epochs,
575
- 'planner_kwargs': planner_kwargs,
576
- 'plan_kwargs': plan_kwargs,
577
- 'verbose': self.verbose,
578
- 'wrapped_bool_actions': self.wrapped_bool_actions,
579
- 'eval_trials': self.eval_trials,
580
- 'use_guess_last_epoch': self.use_guess_last_epoch
581
- }
582
- return objective_fn, kwargs
583
-
584
-
585
- # ===============================================================================
586
- #
587
- # DEEP REACTIVE POLICIES
588
- #
589
- # ===============================================================================
590
- def objective_drp(params, kwargs, key, index):
591
-
592
- # transform hyper-parameters to natural space
593
- param_values = [
594
- pmap(params[name])
595
- for (name, (*_, pmap)) in kwargs['hyperparams_dict'].items()
596
- ]
597
-
598
- # unpack hyper-parameters
599
- lr, w, layers, neurons = param_values
600
- key, subkey = jax.random.split(key)
601
- if kwargs['verbose']:
602
- print(f'[{index}] key={subkey[0]}, '
603
- f'lr={lr}, w={w}, layers={layers}, neurons={neurons}...', flush=True)
604
-
605
- # initialize planning algorithm
606
- planner = JaxBackpropPlanner(
607
- rddl=deepcopy(kwargs['rddl']),
608
- plan=JaxDeepReactivePolicy(
609
- topology=[neurons] * layers,
610
- **kwargs['plan_kwargs']),
611
- optimizer_kwargs={'learning_rate': lr},
612
- **kwargs['planner_kwargs'])
613
- policy_hparams = {name: None for name in planner._action_bounds}
614
- model_params = {name: w for name in planner.compiled.model_params}
615
-
616
- # initialize policy
617
- policy = JaxOfflineController(
618
- planner=planner,
619
- key=subkey,
620
- eval_hyperparams=policy_hparams,
621
- train_on_reset=True,
622
- epochs=kwargs['train_epochs'],
623
- train_seconds=kwargs['timeout_training'],
624
- model_params=model_params,
625
- policy_hyperparams=policy_hparams,
626
- print_summary=False,
627
- print_progress=False,
628
- tqdm_position=index)
629
-
630
- # initialize env for evaluation (need fresh copy to avoid concurrency)
631
- env = RDDLEnv(domain=kwargs['domain'],
632
- instance=kwargs['instance'],
633
- vectorized=True,
634
- enforce_action_constraints=False)
635
-
636
- # perform training
637
- average_reward = 0.0
638
- for trial in range(kwargs['eval_trials']):
639
- key, subkey = jax.random.split(key)
640
- total_reward = policy.evaluate(env, seed=np.array(subkey)[0])['mean']
641
- if kwargs['verbose']:
642
- print(f' [{index}] trial {trial + 1} key={subkey[0]}, '
643
- f'reward={total_reward}', flush=True)
644
- average_reward += total_reward / kwargs['eval_trials']
645
- if kwargs['verbose']:
646
- print(f'[{index}] average reward={average_reward}', flush=True)
647
- return average_reward
648
-
649
-
650
- def power_two_int(x):
651
- return 2 ** int(x)
652
-
653
-
654
- class JaxParameterTuningDRP(JaxParameterTuning):
655
-
656
- def __init__(self, *args,
657
- hyperparams_dict: Dict[str, Tuple[float, float, Callable]]={
658
- 'lr': (-7., 2., power_ten),
659
- 'w': (0., 5., power_ten),
660
- 'layers': (1., 3., int),
661
- 'neurons': (2., 9., power_two_int)
662
- },
663
- **kwargs) -> None:
664
- '''Creates a new tuning class for deep reactive policies.
665
-
666
- :param *args: arguments to pass to parent class
667
- :param hyperparams_dict: same as parent class, but here must contain
668
- learning rate (lr), model weight (w), number of hidden layers (layers)
669
- and number of neurons per hidden layer (neurons)
670
- :param **kwargs: keyword arguments to pass to parent class
671
- '''
672
-
673
- super(JaxParameterTuningDRP, self).__init__(
674
- *args, hyperparams_dict=hyperparams_dict, **kwargs)
675
-
676
- def _pickleable_objective_with_kwargs(self):
677
- objective_fn = objective_drp
678
-
679
- # duplicate planner and plan keyword arguments must be removed
680
- plan_kwargs = self.plan_kwargs.copy()
681
- plan_kwargs.pop('topology', None)
682
-
683
- planner_kwargs = self.planner_kwargs.copy()
684
- planner_kwargs.pop('rddl', None)
685
- planner_kwargs.pop('plan', None)
686
- planner_kwargs.pop('optimizer_kwargs', None)
687
-
688
- kwargs = {
689
- 'rddl': self.env.model,
690
- 'domain': self.env.domain_text,
691
- 'instance': self.env.instance_text,
692
- 'hyperparams_dict': self.hyperparams_dict,
693
- 'timeout_training': self.timeout_training,
694
- 'train_epochs': self.train_epochs,
695
- 'planner_kwargs': planner_kwargs,
696
- 'plan_kwargs': plan_kwargs,
697
- 'verbose': self.verbose,
698
- 'eval_trials': self.eval_trials
699
- }
700
- return objective_fn, kwargs
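
Putting the new interface together, an end-to-end tuning run might look roughly like the sketch below. The file names, tag, and bounds are placeholders; the released workflow is in examples/run_tune.py together with the new examples/configs/tuning_*.cfg templates:

from pyRDDLGym.core.env import RDDLEnv
from pyRDDLGym_jax.core.tuning import Hyperparameter, JaxParameterTuning

def power_ten(x):
    return 10.0 ** x

env = RDDLEnv(domain='domain.rddl', instance='instance.rddl', vectorized=True)
with open('tuning_template.cfg') as fp:        # hypothetical template containing TUNE_LR
    config_template = fp.read()

tuning = JaxParameterTuning(
    env=env,
    config_template=config_template,
    hyperparams=[Hyperparameter('TUNE_LR', -5.0, 1.0, power_ten)],
    online=False,           # offline: train once per trial, then roll out
    eval_trials=3,
    num_workers=2)

best = tuning.tune(key=42, log_file='tuning_log.csv', show_dashboard=False)
print(best)                 # best config values found, keyed by tag
print(tuning.best_config)   # full config string with the tags substituted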