pyRDDLGym-jax 0.4-py3-none-any.whl → 1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. pyRDDLGym_jax/__init__.py +1 -1
  2. pyRDDLGym_jax/core/compiler.py +463 -592
  3. pyRDDLGym_jax/core/logic.py +832 -530
  4. pyRDDLGym_jax/core/planner.py +422 -474
  5. pyRDDLGym_jax/core/simulator.py +7 -5
  6. pyRDDLGym_jax/core/tuning.py +390 -584
  7. pyRDDLGym_jax/core/visualization.py +1463 -0
  8. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
  9. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +5 -5
  10. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
  11. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
  12. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
  13. pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
  14. pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
  15. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
  16. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +5 -4
  17. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
  18. pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
  19. pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
  20. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
  21. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +5 -4
  22. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
  23. pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
  24. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
  25. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +7 -6
  26. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
  27. pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
  28. pyRDDLGym_jax/examples/configs/default_replan.cfg +5 -4
  29. pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
  30. pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
  31. pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
  32. pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
  33. pyRDDLGym_jax/examples/run_plan.py +4 -1
  34. pyRDDLGym_jax/examples/run_tune.py +40 -29
  35. {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-1.0.dist-info}/METADATA +164 -105
  36. pyRDDLGym_jax-1.0.dist-info/RECORD +45 -0
  37. {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-1.0.dist-info}/WHEEL +1 -1
  38. pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
  39. pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
  40. pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
  41. pyRDDLGym_jax-0.4.dist-info/RECORD +0 -44
  42. {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-1.0.dist-info}/LICENSE +0 -0
  43. {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-1.0.dist-info}/top_level.txt +0 -0
@@ -1,82 +1,94 @@
- from copy import deepcopy
  import csv
  import datetime
- from multiprocessing import get_context
+ import threading
+ import multiprocessing
  import os
  import time
- from typing import Any, Callable, Dict, Optional, Tuple
+ from typing import Any, Callable, Dict, Iterable, Optional, Tuple
  import warnings
  warnings.filterwarnings("ignore")

+ from sklearn.gaussian_process.kernels import Matern, ConstantKernel
  from bayes_opt import BayesianOptimization
- from bayes_opt.util import UtilityFunction
+ from bayes_opt.acquisition import AcquisitionFunction, UpperConfidenceBound
  import jax
  import numpy as np

- from pyRDDLGym.core.debug.exception import raise_warning
  from pyRDDLGym.core.env import RDDLEnv

  from pyRDDLGym_jax.core.planner import (
      JaxBackpropPlanner,
-     JaxStraightLinePlan,
-     JaxDeepReactivePolicy,
      JaxOfflineController,
-     JaxOnlineController
+     JaxOnlineController,
+     load_config_from_string
  )

+ # try to load the dash board
+ try:
+     from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard
+ except Exception:
+     raise_warning('Failed to load the dashboard visualization tool: '
+                   'please make sure you have installed the required packages.',
+                   'red')
+     traceback.print_exc()
+     JaxPlannerDashboard = None
+
+
+ class Hyperparameter:
+     '''A generic hyper-parameter of the planner that can be tuned.'''
+
+     def __init__(self, tag: str, lower_bound: float, upper_bound: float,
+                  search_to_config_map: Callable) -> None:
+         self.tag = tag
+         self.lower_bound = lower_bound
+         self.upper_bound = upper_bound
+         self.search_to_config_map = search_to_config_map
+
+     def __str__(self) -> str:
+         return (f'{self.search_to_config_map.__name__} '
+                 f': [{self.lower_bound}, {self.upper_bound}] -> {self.tag}')
+
+
  Kwargs = Dict[str, Any]
+ ParameterValues = Dict[str, Any]
+ Hyperparameters = Iterable[Hyperparameter]
+
+ COLUMNS = ['pid', 'worker', 'iteration', 'target', 'best_target', 'acq_params']


- # ===============================================================================
- #
- # GENERIC TUNING MODULE
- #
- # Currently contains three implementations:
- #     1. straight line plan
- #     2. re-planning
- #     3. deep reactive policies
- #
- # ===============================================================================
  class JaxParameterTuning:
      '''A general-purpose class for tuning a Jax planner.'''

      def __init__(self, env: RDDLEnv,
-                  hyperparams_dict: Dict[str, Tuple[float, float, Callable]],
-                  train_epochs: int,
-                  timeout_training: float,
-                  timeout_tuning: float=np.inf,
+                  config_template: str,
+                  hyperparams: Hyperparameters,
+                  online: bool,
                   eval_trials: int=5,
                   verbose: bool=True,
-                  planner_kwargs: Optional[Kwargs]=None,
-                  plan_kwargs: Optional[Kwargs]=None,
+                  timeout_tuning: float=np.inf,
                   pool_context: str='spawn',
                   num_workers: int=1,
                   poll_frequency: float=0.2,
                   gp_iters: int=25,
-                  acquisition: Optional[UtilityFunction]=None,
+                  acquisition: Optional[AcquisitionFunction]=None,
                   gp_init_kwargs: Optional[Kwargs]=None,
                   gp_params: Optional[Kwargs]=None) -> None:
          '''Creates a new instance for tuning hyper-parameters for Jax planners
          on the given RDDL domain and instance.

          :param env: the RDDLEnv describing the MDP to optimize
-         :param hyperparams_dict: dictionary mapping name of each hyperparameter
-         to a triple, where the first two elements are lower/upper bounds on the
-         parameter value, and the last is a callable mapping the parameter to its
-         RDDL equivalent
-         :param train_epochs: the maximum number of iterations of SGD per
-         step or trial
-         :param timeout_training: the maximum amount of time to spend training per
-         trial/decision step (in seconds)
+         :param config_template: base configuration file content to tune: regex
+         matches are specified directly in the config and map to keys in the
+         hyperparams_dict field
+         :param hyperparams: list of hyper-parameters to regex replace in the
+         config template during tuning
+         :param online: whether the planner is optimized online or offline
          :param timeout_tuning: the maximum amount of time to spend tuning
          hyperparameters in general (in seconds)
          :param eval_trials: how many trials to perform independent training
          in order to estimate the return for each set of hyper-parameters
          :param verbose: whether to print intermediate results of tuning
-         :param planner_kwargs: additional arguments to feed to the planner
-         :param plan_kwargs: additional arguments to feed to the plan/policy
-         :param pool_context: context for multiprocessing pool (defaults to
-         "spawn")
+         :param pool_context: context for multiprocessing pool (default "spawn")
          :param num_workers: how many points to evaluate in parallel
          :param poll_frequency: how often (in seconds) to poll for completed
          jobs, necessary if num_workers > 1
@@ -86,21 +98,20 @@ class JaxParameterTuning:
          during initialization
          :param gp_params: additional parameters to feed to Bayesian optimizer
          after initialization optimization
-         '''
-
+         '''
+         # objective parameters
          self.env = env
+         self.config_template = config_template
+         hyperparams_dict = {hyper_param.tag: hyper_param
+                             for hyper_param in hyperparams
+                             if hyper_param.tag in config_template}
          self.hyperparams_dict = hyperparams_dict
-         self.train_epochs = train_epochs
-         self.timeout_training = timeout_training
-         self.timeout_tuning = timeout_tuning
+         self.online = online
          self.eval_trials = eval_trials
          self.verbose = verbose
-         if planner_kwargs is None:
-             planner_kwargs = {}
-         self.planner_kwargs = planner_kwargs
-         if plan_kwargs is None:
-             plan_kwargs = {}
-         self.plan_kwargs = plan_kwargs
+
+         # Bayesian parameters
+         self.timeout_tuning = timeout_tuning
          self.pool_context = pool_context
          self.num_workers = num_workers
          self.poll_frequency = poll_frequency
@@ -109,21 +120,33 @@ class JaxParameterTuning:
              gp_init_kwargs = {}
          self.gp_init_kwargs = gp_init_kwargs
          if gp_params is None:
-             gp_params = {'n_restarts_optimizer': 10}
+             gp_params = {'n_restarts_optimizer': 25,
+                          'kernel': self.make_default_kernel()}
          self.gp_params = gp_params
-
-         # create acquisition function
-         self.acq_args = None
          if acquisition is None:
              num_samples = self.gp_iters * self.num_workers
-             acquisition, self.acq_args = JaxParameterTuning._annealing_utility(num_samples)
+             acquisition = JaxParameterTuning.annealing_acquisition(num_samples)
          self.acquisition = acquisition

+     @staticmethod
+     def make_default_kernel():
+         weight1 = ConstantKernel(1.0, (0.01, 100.0))
+         weight2 = ConstantKernel(1.0, (0.01, 100.0))
+         weight3 = ConstantKernel(1.0, (0.01, 100.0))
+         kernel1 = Matern(length_scale=0.5, length_scale_bounds=(0.1, 0.5), nu=2.5)
+         kernel2 = Matern(length_scale=1.0, length_scale_bounds=(0.5, 1.0), nu=2.5)
+         kernel3 = Matern(length_scale=5.0, length_scale_bounds=(1.0, 5.0), nu=2.5)
+         return weight1 * kernel1 + weight2 * kernel2 + weight3 * kernel3
+
      def summarize_hyperparameters(self) -> None:
+         hyper_params_table = []
+         for (_, param) in self.hyperparams_dict.items():
+             hyper_params_table.append(f' {str(param)}')
+         hyper_params_table = '\n'.join(hyper_params_table)
          print(f'hyperparameter optimizer parameters:\n'
-               f' tuned_hyper_parameters ={self.hyperparams_dict}\n'
+               f' tuned_hyper_parameters =\n{hyper_params_table}\n'
               f' initialization_args ={self.gp_init_kwargs}\n'
-              f' additional_args ={self.gp_params}\n'
+              f' gp_params ={self.gp_params}\n'
               f' tuning_iterations ={self.gp_iters}\n'
               f' tuning_timeout ={self.timeout_tuning}\n'
               f' tuning_batch_size ={self.num_workers}\n'
@@ -131,154 +154,348 @@ class JaxParameterTuning:
               f' mp_pool_poll_frequency ={self.poll_frequency}\n'
               f'meta-objective parameters:\n'
               f' planning_trials_per_iter ={self.eval_trials}\n'
-              f' planning_iters_per_trial ={self.train_epochs}\n'
-              f' planning_timeout_per_trial={self.timeout_training}\n'
-              f' acquisition_fn ={type(self.acquisition).__name__}')
-         if self.acq_args is not None:
-             print(f'using default acquisition function:\n'
-                   f' utility_kind ={self.acq_args[0]}\n'
-                   f' initial_kappa={self.acq_args[1]}\n'
-                   f' kappa_decay ={self.acq_args[2]}')
+              f' acquisition_fn ={self.acquisition}')

      @staticmethod
-     def _annealing_utility(n_samples, n_delay_samples=0, kappa1=10.0, kappa2=1.0):
-         kappa_decay = (kappa2 / kappa1) ** (1.0 / (n_samples - n_delay_samples))
-         utility_fn = UtilityFunction(
-             kind='ucb',
+     def annealing_acquisition(n_samples: int, n_delay_samples: int=0,
+                               kappa1: float=10.0, kappa2: float=1.0) -> UpperConfidenceBound:
+         acq_fn = UpperConfidenceBound(
             kappa=kappa1,
-             kappa_decay=kappa_decay,
-             kappa_decay_delay=n_delay_samples)
-         utility_args = ['ucb', kappa1, kappa_decay]
-         return utility_fn, utility_args
+             exploration_decay=(kappa2 / kappa1) ** (1.0 / (n_samples - n_delay_samples)),
+             exploration_decay_delay=n_delay_samples
+         )
+         return acq_fn
+
+     @staticmethod
+     def search_to_config_params(hyper_params: Hyperparameters,
+                                 params: ParameterValues) -> ParameterValues:
+         config_params = {
+             tag: param.search_to_config_map(params[tag])
+             for (tag, param) in hyper_params.items()
+         }
+         return config_params
+
+     @staticmethod
+     def config_from_template(config_template: str,
+                              config_params: ParameterValues) -> str:
+         config_string = config_template
+         for (tag, param_value) in config_params.items():
+             config_string = config_string.replace(tag, str(param_value))
+         return config_string

-     def _pickleable_objective_with_kwargs(self):
-         raise NotImplementedError
+     @property
+     def best_config(self) -> str:
+         return self.config_from_template(self.config_template, self.best_params)
+
+     @staticmethod
+     def queue_listener(queue, dashboard):
+         while True:
+             args = queue.get()
+             if args is None:
+                 break
+             elif len(args) == 2:
+                 dashboard.update_experiment(*args)
+             else:
+                 dashboard.register_experiment(*args)

      @staticmethod
-     def _wrapped_evaluate(index, params, key, func, kwargs):
-         target = func(params=params, kwargs=kwargs, key=key, index=index)
+     def offline_trials(env, planner, train_args, key, iteration, index, num_trials,
+                        verbose, viz, queue):
+         average_reward = 0.0
+         for trial in range(num_trials):
+             key, subkey = jax.random.split(key)
+             experiment_id = f'iter={iteration}, worker={index}, trial={trial}'
+             if queue is not None:
+                 queue.put((
+                     experiment_id,
+                     JaxPlannerDashboard.get_planner_info(planner),
+                     subkey[0],
+                     viz
+                 ))
+
+             # train the policy
+             callback = None
+             for callback in planner.optimize_generator(key=subkey, **train_args):
+                 if queue is not None and queue.empty():
+                     queue.put((experiment_id, callback))
+             best_params = None if callback is None else callback['best_params']
+
+             # evaluate the policy in the real environment
+             policy = JaxOfflineController(
+                 planner=planner, key=subkey, tqdm_position=index,
+                 params=best_params, train_on_reset=False)
+             total_reward = policy.evaluate(env, seed=np.array(subkey)[0])['mean']
+
+             # update average reward
+             if verbose:
+                 iters = None if callback is None else callback['iteration']
+                 print(f' [{index}] trial {trial + 1}, key={subkey[0]}, '
+                       f'reward={total_reward:.6f}, iters={iters}', flush=True)
+             average_reward += total_reward / num_trials
+
+         if verbose:
+             print(f'[{index}] average reward={average_reward:.6f}', flush=True)
+         return average_reward
+
+     @staticmethod
+     def online_trials(env, planner, train_args, key, iteration, index, num_trials,
+                       verbose, viz, queue):
+         average_reward = 0.0
+         for trial in range(num_trials):
+             key, subkey = jax.random.split(key)
+             experiment_id = f'iter={iteration}, worker={index}, trial={trial}'
+             if queue is not None:
+                 queue.put((
+                     experiment_id,
+                     JaxPlannerDashboard.get_planner_info(planner),
+                     subkey[0],
+                     viz
+                 ))
+
+             # initialize the online policy
+             policy = JaxOnlineController(
+                 planner=planner, key=subkey, tqdm_position=index, **train_args)
+
+             # evaluate the policy in the real environment
+             total_reward = 0.0
+             callback = None
+             state, _ = env.reset(seed=np.array(subkey)[0])
+             elapsed_time = 0.0
+             for step in range(env.horizon):
+                 action = policy.sample_action(state)
+                 next_state, reward, terminated, truncated, _ = env.step(action)
+                 total_reward += reward
+                 done = terminated or truncated
+                 state = next_state
+                 callback = policy.callback
+                 elapsed_time += callback['elapsed_time']
+                 callback['iteration'] = step
+                 callback['progress'] = int(100 * (step + 1.) / env.horizon)
+                 callback['elapsed_time'] = elapsed_time
+                 if queue is not None and queue.empty():
+                     queue.put((experiment_id, callback))
+                 if done:
+                     break
+
+             # update average reward
+             if verbose:
+                 iters = None if callback is None else callback['iteration']
+                 print(f' [{index}] trial {trial + 1}, key={subkey[0]}, '
+                       f'reward={total_reward:.6f}, iters={iters}', flush=True)
+             average_reward += total_reward / num_trials
+
+         if verbose:
+             print(f'[{index}] average reward={average_reward:.6f}', flush=True)
+         return average_reward
+
+     @staticmethod
+     def objective_function(params: ParameterValues,
+                            key: jax.random.PRNGKey,
+                            index: int,
+                            iteration: int,
+                            kwargs: Kwargs,
+                            queue: object) -> Tuple[ParameterValues, float, int, int]:
+         '''A pickleable objective function to evaluate a single hyper-parameter
+         configuration.'''
+
+         hyperparams_dict = kwargs['hyperparams_dict']
+         config_template = kwargs['config_template']
+         online = kwargs['online']
+         domain = kwargs['domain']
+         instance = kwargs['instance']
+         num_trials = kwargs['eval_trials']
+         viz = kwargs['viz']
+         verbose = kwargs['verbose']
+
+         # config string substitution and parsing
+         config_params = JaxParameterTuning.search_to_config_params(hyperparams_dict, params)
+         if verbose:
+             config_param_str = ', '.join(
+                 f'{k}={v}' for (k, v) in config_params.items())
+             print(f'[{index}] key={key[0]}, {config_param_str}', flush=True)
+         config_string = JaxParameterTuning.config_from_template(config_template, config_params)
+         planner_args, _, train_args = load_config_from_string(config_string)
+
+         # remove keywords that should not be in the tuner
+         train_args.pop('dashboard', None)
+
+         # initialize env for evaluation (need fresh copy to avoid concurrency)
+         env = RDDLEnv(domain, instance, vectorized=True, enforce_action_constraints=False)
+
+         # run planning algorithm
+         planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
+         if online:
+             average_reward = JaxParameterTuning.online_trials(
+                 env, planner, train_args, key, iteration, index, num_trials,
+                 verbose, viz, queue
+             )
+         else:
+             average_reward = JaxParameterTuning.offline_trials(
+                 env, planner, train_args, key, iteration, index,
+                 num_trials, verbose, viz, queue
+             )
+
         pid = os.getpid()
-         return index, pid, params, target
-
-     def tune(self, key: jax.random.PRNGKey,
-              filename: str,
-              save_plot: bool=False) -> Dict[str, Any]:
+         return params, average_reward, index, pid
+
+     def tune_optimizer(self, optimizer: BayesianOptimization) -> None:
+         '''Tunes the Bayesian optimization algorithm hyper-parameters.'''
+         print('\n' + f'The current kernel is {repr(optimizer._gp.kernel_)}.')
+
+     def tune(self, key: int, log_file: str, show_dashboard: bool=False) -> ParameterValues:
          '''Tunes the hyper-parameters for Jax planner, returns the best found.'''
+
          self.summarize_hyperparameters()

-         start_time = time.time()
+         # clear and prepare output file
+         with open(log_file, 'w', newline='') as file:
+             writer = csv.writer(file)
+             writer.writerow(COLUMNS + list(self.hyperparams_dict.keys()))

-         # objective function
-         objective = self._pickleable_objective_with_kwargs()
-         evaluate = JaxParameterTuning._wrapped_evaluate
+         # create a dash-board for visualizing experiment runs
+         if show_dashboard:
+             dashboard = JaxPlannerDashboard()
+             dashboard.launch()

+         # objective function auxiliary data
+         obj_kwargs = {
+             'hyperparams_dict': self.hyperparams_dict,
+             'config_template': self.config_template,
+             'online': self.online,
+             'domain': self.env.domain_text,
+             'instance': self.env.instance_text,
+             'eval_trials': self.eval_trials,
+             'viz': self.env._visualizer,
+             'verbose': self.verbose
+         }
+
         # create optimizer
         hyperparams_bounds = {
-             name: hparam[:2]
-             for (name, hparam) in self.hyperparams_dict.items()
+             tag: (param.lower_bound, param.upper_bound)
+             for (tag, param) in self.hyperparams_dict.items()
         }
         optimizer = BayesianOptimization(
-             f=None,  # probe() is not called
+             f=None,
+             acquisition_function=self.acquisition,
             pbounds=hyperparams_bounds,
             allow_duplicate_points=True,  # to avoid crash
             random_state=np.random.RandomState(key),
             **self.gp_init_kwargs
         )
         optimizer.set_gp_params(**self.gp_params)
-         utility = self.acquisition

         # suggest initial parameters to evaluate
         num_workers = self.num_workers
-         suggested, kappas = [], []
+         suggested_params, acq_params = [], []
         for _ in range(num_workers):
-             utility.update_params()
-             probe = optimizer.suggest(utility)
-             suggested.append(probe)
-             kappas.append(utility.kappa)
+             probe = optimizer.suggest()
+             suggested_params.append(probe)
+             acq_params.append(vars(optimizer.acquisition_function))

-         # clear and prepare output file
-         filename = self._filename(filename, 'csv')
-         with open(filename, 'w', newline='') as file:
-             writer = csv.writer(file)
-             writer.writerow(
-                 ['pid', 'worker', 'iteration', 'target', 'best_target', 'kappa'] + \
-                 list(hyperparams_bounds.keys())
-             )
-
-         # start multiprocess evaluation
-         worker_ids = list(range(num_workers))
-         best_params, best_target = None, -np.inf
-
-         for it in range(self.gp_iters):
+         with multiprocessing.Manager() as manager:

-             # check if there is enough time left for another iteration
-             elapsed = time.time() - start_time
-             if elapsed >= self.timeout_tuning:
-                 print(f'global time limit reached at iteration {it}, aborting')
-                 break
+             # queue and parallel thread for handing render events
+             if show_dashboard:
+                 queue = manager.Queue()
+                 dashboard_thread = threading.Thread(
+                     target=JaxParameterTuning.queue_listener,
+                     args=(queue, dashboard)
+                 )
+                 dashboard_thread.start()
+             else:
+                 queue = None

-             # continue with next iteration
-             print('\n' + '*' * 25 +
-                   '\n' + f'[{datetime.timedelta(seconds=elapsed)}] ' +
-                   f'starting iteration {it}' +
-                   '\n' + '*' * 25)
-             key, *subkeys = jax.random.split(key, num=num_workers + 1)
-             rows = [None] * num_workers
+             # start multiprocess evaluation
+             worker_ids = list(range(num_workers))
+             best_params, best_target = None, -np.inf
+             key = jax.random.PRNGKey(key)
+             start_time = time.time()

-             # create worker pool: note each iteration must wait for all workers
-             # to finish before moving to the next
-             with get_context(self.pool_context).Pool(processes=num_workers) as pool:
+             for it in range(self.gp_iters):

-                 # assign jobs to worker pool
-                 # - each trains on suggested parameters from the last iteration
-                 # - this way, since each job finishes asynchronously, these
-                 #   parameters usually differ across jobs
-                 results = [
-                     pool.apply_async(evaluate, worker_args + objective)
-                     for worker_args in zip(worker_ids, suggested, subkeys)
-                 ]
-
-                 # wait for all workers to complete
-                 while results:
-                     time.sleep(self.poll_frequency)
-
-                     # determine which jobs have completed
-                     jobs_done = []
-                     for (i, candidate) in enumerate(results):
-                         if candidate.ready():
-                             jobs_done.append(i)
+                 # check if there is enough time left for another iteration
+                 elapsed = time.time() - start_time
+                 if elapsed >= self.timeout_tuning:
+                     print(f'global time limit reached at iteration {it}, aborting')
+                     break
+
+                 # continue with next iteration
+                 print('\n' + '*' * 80 +
+                       f'\n[{datetime.timedelta(seconds=elapsed)}] ' +
+                       f'starting iteration {it + 1}' +
+                       '\n' + '*' * 80)
+                 key, *subkeys = jax.random.split(key, num=num_workers + 1)
+                 rows = [None] * num_workers
+                 old_best_target = best_target
+
+                 # create worker pool: note each iteration must wait for all workers
+                 # to finish before moving to the next
+                 with multiprocessing.get_context(
+                     self.pool_context).Pool(processes=num_workers) as pool:

-                     # get result from completed jobs
-                     for i in jobs_done[::-1]:
-
-                         # extract and register the new evaluation
-                         index, pid, params, target = results.pop(i).get()
-                         optimizer.register(params, target)
-
-                         # update acquisition function and suggest a new point
-                         utility.update_params()
-                         suggested[index] = optimizer.suggest(utility)
-                         old_kappa = kappas[index]
-                         kappas[index] = utility.kappa
-
-                         # transform suggestion back to natural space
-                         rddl_params = {
-                             name: pf(params[name])
-                             for (name, (*_, pf)) in self.hyperparams_dict.items()
-                         }
-
-                         # update the best suggestion so far
-                         if target > best_target:
-                             best_params, best_target = rddl_params, target
+                     # assign jobs to worker pool
+                     results = [
+                         pool.apply_async(JaxParameterTuning.objective_function,
+                                          obj_args + (it, obj_kwargs, queue))
+                         for obj_args in zip(suggested_params, subkeys, worker_ids)
+                     ]
+
+                     # wait for all workers to complete
+                     while results:
+                         time.sleep(self.poll_frequency)

-                         # write progress to file in real time
-                         rows[index] = [pid, index, it, target, best_target, old_kappa] + \
-                             list(rddl_params.values())
+                         # determine which jobs have completed
+                         jobs_done = []
+                         for (i, candidate) in enumerate(results):
+                             if candidate.ready():
+                                 jobs_done.append(i)

-             # write results of all processes in current iteration to file
-             with open(filename, 'a', newline='') as file:
-                 writer = csv.writer(file)
-                 writer.writerows(rows)
+                         # get result from completed jobs
+                         for i in jobs_done[::-1]:
+
+                             # extract and register the new evaluation
+                             params, target, index, pid = results.pop(i).get()
+                             optimizer.register(params, target)
+                             optimizer._gp.fit(
+                                 optimizer.space.params, optimizer.space.target)
+
+                             # update acquisition function and suggest a new point
+                             suggested_params[index] = optimizer.suggest()
+                             old_acq_params = acq_params[index]
+                             acq_params[index] = vars(optimizer.acquisition_function)
+
+                             # transform suggestion back to natural space
+                             config_params = JaxParameterTuning.search_to_config_params(
+                                 self.hyperparams_dict, params)
+
+                             # update the best suggestion so far
+                             if target > best_target:
+                                 best_params, best_target = config_params, target
+
+                             rows[index] = [pid, index, it, target,
+                                            best_target, old_acq_params] + \
+                                 list(config_params.values())
+
+                 # print best parameter if found
+                 if best_target > old_best_target:
+                     print(f'* found new best average reward {best_target:.6f}')
+
+                 # tune the optimizer here
+                 self.tune_optimizer(optimizer)
+
+                 # write results of all processes in current iteration to file
+                 with open(log_file, 'a', newline='') as file:
+                     writer = csv.writer(file)
+                     writer.writerows(rows)
+
+                 # update the dashboard tuning
+                 if show_dashboard:
+                     dashboard.update_tuning(optimizer, hyperparams_bounds)
+
+             # stop the queue listener thread
+             if show_dashboard:
+                 queue.put(None)
+                 dashboard_thread.join()

         # print summary of results
         elapsed = time.time() - start_time
@@ -286,420 +503,9 @@ class JaxParameterTuning:
               f' time_elapsed ={datetime.timedelta(seconds=elapsed)}\n'
               f' iterations ={it + 1}\n'
               f' best_hyper_parameters={best_params}\n'
-              f' best_meta_objective ={best_target}\n')
+              f' best_meta_objective ={best_target}\n')

-         if save_plot:
-             self._save_plot(filename)
+         self.best_params = best_params
+         self.optimizer = optimizer
+         self.log_file = log_file
         return best_params
-
-     def _filename(self, name, ext):
-         domain_name = ''.join(c for c in self.env.model.domain_name
-                               if c.isalnum() or c == '_')
-         instance_name = ''.join(c for c in self.env.model.instance_name
-                                 if c.isalnum() or c == '_')
-         filename = f'{name}_{domain_name}_{instance_name}.{ext}'
-         return filename
-
-     def _save_plot(self, filename):
-         try:
-             import matplotlib.pyplot as plt
-             from sklearn.manifold import MDS
-         except Exception as e:
-             raise_warning(f'failed to import packages matplotlib or sklearn, '
-                           f'aborting plot of search space\n{e}', 'red')
-         else:
-             data = np.loadtxt(filename, delimiter=',', dtype=object)
-             data, target = data[1:, 3:], data[1:, 2]
-             data = data.astype(np.float64)
-             target = target.astype(np.float64)
-             target = (target - np.min(target)) / (np.max(target) - np.min(target))
-             embedding = MDS(n_components=2, normalized_stress='auto')
-             data1 = embedding.fit_transform(data)
-             sc = plt.scatter(data1[:, 0], data1[:, 1], c=target, s=4.,
-                              cmap='seismic', edgecolor='gray',
-                              linewidth=0.01, alpha=0.4)
-             plt.colorbar(sc)
-             plt.savefig(self._filename('gp_points', 'pdf'))
-             plt.clf()
-             plt.close()
-
-
- # ===============================================================================
- #
- # STRAIGHT LINE PLANNING
- #
- # ===============================================================================
- def objective_slp(params, kwargs, key, index):
-
-     # transform hyper-parameters to natural space
-     param_values = [
-         pmap(params[name])
-         for (name, (*_, pmap)) in kwargs['hyperparams_dict'].items()
-     ]
-
-     # unpack hyper-parameters
-     if kwargs['wrapped_bool_actions']:
-         std, lr, w, wa = param_values
-     else:
-         std, lr, w = param_values
-         wa = None
-     if kwargs['verbose']:
-         print(f'[{index}] key={key}, std={std}, lr={lr}, w={w}, wa={wa}...', flush=True)
-
-     # initialize planning algorithm
-     planner = JaxBackpropPlanner(
-         rddl=deepcopy(kwargs['rddl']),
-         plan=JaxStraightLinePlan(
-             initializer=jax.nn.initializers.normal(std),
-             **kwargs['plan_kwargs']),
-         optimizer_kwargs={'learning_rate': lr},
-         **kwargs['planner_kwargs'])
-     policy_hparams = {name: wa for name in kwargs['wrapped_bool_actions']}
-     model_params = {name: w for name in planner.compiled.model_params}
-
-     # initialize policy
-     key, subkey = jax.random.split(key)
-     policy = JaxOfflineController(
-         planner=planner,
-         key=subkey,
-         eval_hyperparams=policy_hparams,
-         train_on_reset=True,
-         epochs=kwargs['train_epochs'],
-         train_seconds=kwargs['timeout_training'],
-         model_params=model_params,
-         policy_hyperparams=policy_hparams,
-         print_summary=False,
-         print_progress=False,
-         tqdm_position=index)
-
-     # initialize env for evaluation (need fresh copy to avoid concurrency)
-     env = RDDLEnv(domain=kwargs['domain'],
-                   instance=kwargs['instance'],
-                   vectorized=True,
-                   enforce_action_constraints=False)
-
-     # perform training
-     average_reward = 0.0
-     for trial in range(kwargs['eval_trials']):
-         key, subkey = jax.random.split(key)
-         total_reward = policy.evaluate(env, seed=np.array(subkey)[0])['mean']
-         if kwargs['verbose']:
-             print(f' [{index}] trial {trial + 1} key={subkey}, '
-                   f'reward={total_reward}', flush=True)
-         average_reward += total_reward / kwargs['eval_trials']
-     if kwargs['verbose']:
-         print(f'[{index}] average reward={average_reward}', flush=True)
-     return average_reward
-
-
- def power_ten(x):
-     return 10.0 ** x
-
-
- class JaxParameterTuningSLP(JaxParameterTuning):
-
-     def __init__(self, *args,
-                  hyperparams_dict: Dict[str, Tuple[float, float, Callable]]={
-                      'std': (-5., 2., power_ten),
-                      'lr': (-5., 2., power_ten),
-                      'w': (0., 5., power_ten),
-                      'wa': (0., 5., power_ten)
-                  },
-                  **kwargs) -> None:
-         '''Creates a new tuning class for straight line planners.
-
-         :param *args: arguments to pass to parent class
-         :param hyperparams_dict: same as parent class, but here must contain
-         weight initialization (std), learning rate (lr), model weight (w), and
-         action weight (wa) if wrap_sigmoid and boolean action fluents exist
-         :param **kwargs: keyword arguments to pass to parent class
-         '''
-
-         super(JaxParameterTuningSLP, self).__init__(
-             *args, hyperparams_dict=hyperparams_dict, **kwargs)
-
-         # action parameters required if wrap_sigmoid and boolean action exists
-         self.wrapped_bool_actions = []
-         if self.plan_kwargs.get('wrap_sigmoid', True):
-             for var in self.env.model.action_fluents:
-                 if self.env.model.variable_ranges[var] == 'bool':
-                     self.wrapped_bool_actions.append(var)
-         if not self.wrapped_bool_actions:
-             self.hyperparams_dict.pop('wa', None)
-
-     def _pickleable_objective_with_kwargs(self):
-         objective_fn = objective_slp
-
-         # duplicate planner and plan keyword arguments must be removed
-         plan_kwargs = self.plan_kwargs.copy()
-         plan_kwargs.pop('initializer', None)
-
-         planner_kwargs = self.planner_kwargs.copy()
-         planner_kwargs.pop('rddl', None)
-         planner_kwargs.pop('plan', None)
-         planner_kwargs.pop('optimizer_kwargs', None)
-
-         kwargs = {
-             'rddl': self.env.model,
-             'domain': self.env.domain_text,
-             'instance': self.env.instance_text,
-             'hyperparams_dict': self.hyperparams_dict,
-             'timeout_training': self.timeout_training,
-             'train_epochs': self.train_epochs,
-             'planner_kwargs': planner_kwargs,
-             'plan_kwargs': plan_kwargs,
-             'verbose': self.verbose,
-             'wrapped_bool_actions': self.wrapped_bool_actions,
-             'eval_trials': self.eval_trials
-         }
-         return objective_fn, kwargs
-
-
- # ===============================================================================
- #
- # REPLANNING
- #
- # ===============================================================================
- def objective_replan(params, kwargs, key, index):
-
-     # transform hyper-parameters to natural space
-     param_values = [
-         pmap(params[name])
-         for (name, (*_, pmap)) in kwargs['hyperparams_dict'].items()
-     ]
-
-     # unpack hyper-parameters
-     if kwargs['wrapped_bool_actions']:
-         std, lr, w, wa, T = param_values
-     else:
-         std, lr, w, T = param_values
-         wa = None
-     if kwargs['verbose']:
-         print(f'[{index}] key={key}, std={std}, lr={lr}, w={w}, wa={wa}, T={T}...', flush=True)
-
-     # initialize planning algorithm
-     planner = JaxBackpropPlanner(
-         rddl=deepcopy(kwargs['rddl']),
-         plan=JaxStraightLinePlan(
-             initializer=jax.nn.initializers.normal(std),
-             **kwargs['plan_kwargs']),
-         rollout_horizon=T,
-         optimizer_kwargs={'learning_rate': lr},
-         **kwargs['planner_kwargs'])
-     policy_hparams = {name: wa for name in kwargs['wrapped_bool_actions']}
-     model_params = {name: w for name in planner.compiled.model_params}
-
-     # initialize controller
-     key, subkey = jax.random.split(key)
-     policy = JaxOnlineController(
-         planner=planner,
-         key=subkey,
-         eval_hyperparams=policy_hparams,
-         warm_start=kwargs['use_guess_last_epoch'],
-         epochs=kwargs['train_epochs'],
-         train_seconds=kwargs['timeout_training'],
-         model_params=model_params,
-         policy_hyperparams=policy_hparams,
-         print_summary=False,
-         print_progress=False,
-         tqdm_position=index)
-
-     # initialize env for evaluation (need fresh copy to avoid concurrency)
-     env = RDDLEnv(domain=kwargs['domain'],
-                   instance=kwargs['instance'],
-                   vectorized=True,
-                   enforce_action_constraints=False)
-
-     # perform training
-     average_reward = 0.0
-     for trial in range(kwargs['eval_trials']):
-         key, subkey = jax.random.split(key)
-         total_reward = policy.evaluate(env, seed=np.array(subkey)[0])['mean']
-         if kwargs['verbose']:
-             print(f' [{index}] trial {trial + 1} key={subkey}, '
-                   f'reward={total_reward}', flush=True)
-         average_reward += total_reward / kwargs['eval_trials']
-     if kwargs['verbose']:
-         print(f'[{index}] average reward={average_reward}', flush=True)
-     return average_reward
-
-
- class JaxParameterTuningSLPReplan(JaxParameterTuningSLP):
-
-     def __init__(self,
-                  *args,
-                  hyperparams_dict: Dict[str, Tuple[float, float, Callable]]={
-                      'std': (-5., 2., power_ten),
-                      'lr': (-5., 2., power_ten),
-                      'w': (0., 5., power_ten),
-                      'wa': (0., 5., power_ten),
-                      'T': (1, None, int)
-                  },
-                  use_guess_last_epoch: bool=True,
-                  **kwargs) -> None:
-         '''Creates a new tuning class for straight line planners.
-
-         :param *args: arguments to pass to parent class
-         :param hyperparams_dict: same as parent class, but here must contain
-         weight initialization (std), learning rate (lr), model weight (w),
-         action weight (wa) if wrap_sigmoid and boolean action fluents exist, and
-         lookahead horizon (T)
-         :param use_guess_last_epoch: use the trained parameters from previous
-         decision to warm-start next decision
-         :param **kwargs: keyword arguments to pass to parent class
-         '''
-
-         super(JaxParameterTuningSLPReplan, self).__init__(
-             *args, hyperparams_dict=hyperparams_dict, **kwargs)
-
-         self.use_guess_last_epoch = use_guess_last_epoch
-
-         # set upper range of lookahead horizon to environment horizon
-         if self.hyperparams_dict['T'][1] is None:
-             self.hyperparams_dict['T'] = (1, self.env.horizon, int)
-
-     def _pickleable_objective_with_kwargs(self):
-         objective_fn = objective_replan
-
-         # duplicate planner and plan keyword arguments must be removed
-         plan_kwargs = self.plan_kwargs.copy()
-         plan_kwargs.pop('initializer', None)
-
-         planner_kwargs = self.planner_kwargs.copy()
-         planner_kwargs.pop('rddl', None)
-         planner_kwargs.pop('plan', None)
-         planner_kwargs.pop('rollout_horizon', None)
-         planner_kwargs.pop('optimizer_kwargs', None)
-
-         kwargs = {
-             'rddl': self.env.model,
-             'domain': self.env.domain_text,
-             'instance': self.env.instance_text,
-             'hyperparams_dict': self.hyperparams_dict,
-             'timeout_training': self.timeout_training,
-             'train_epochs': self.train_epochs,
-             'planner_kwargs': planner_kwargs,
-             'plan_kwargs': plan_kwargs,
-             'verbose': self.verbose,
-             'wrapped_bool_actions': self.wrapped_bool_actions,
-             'eval_trials': self.eval_trials,
-             'use_guess_last_epoch': self.use_guess_last_epoch
-         }
-         return objective_fn, kwargs
-
-
- # ===============================================================================
- #
- # DEEP REACTIVE POLICIES
- #
- # ===============================================================================
- def objective_drp(params, kwargs, key, index):
-
-     # transform hyper-parameters to natural space
-     param_values = [
-         pmap(params[name])
-         for (name, (*_, pmap)) in kwargs['hyperparams_dict'].items()
-     ]
-
-     # unpack hyper-parameters
-     lr, w, layers, neurons = param_values
-     if kwargs['verbose']:
-         print(f'[{index}] key={key}, lr={lr}, w={w}, layers={layers}, neurons={neurons}...', flush=True)
-
-     # initialize planning algorithm
-     planner = JaxBackpropPlanner(
-         rddl=deepcopy(kwargs['rddl']),
-         plan=JaxDeepReactivePolicy(
-             topology=[neurons] * layers,
-             **kwargs['plan_kwargs']),
-         optimizer_kwargs={'learning_rate': lr},
-         **kwargs['planner_kwargs'])
-     policy_hparams = {name: None for name in planner._action_bounds}
-     model_params = {name: w for name in planner.compiled.model_params}
-
-     # initialize policy
-     key, subkey = jax.random.split(key)
-     policy = JaxOfflineController(
-         planner=planner,
-         key=subkey,
-         eval_hyperparams=policy_hparams,
-         train_on_reset=True,
-         epochs=kwargs['train_epochs'],
-         train_seconds=kwargs['timeout_training'],
-         model_params=model_params,
-         policy_hyperparams=policy_hparams,
-         print_summary=False,
-         print_progress=False,
-         tqdm_position=index)
-
-     # initialize env for evaluation (need fresh copy to avoid concurrency)
-     env = RDDLEnv(domain=kwargs['domain'],
-                   instance=kwargs['instance'],
-                   vectorized=True,
-                   enforce_action_constraints=False)
-
-     # perform training
-     average_reward = 0.0
-     for trial in range(kwargs['eval_trials']):
-         key, subkey = jax.random.split(key)
-         total_reward = policy.evaluate(env, seed=np.array(subkey)[0])['mean']
-         if kwargs['verbose']:
-             print(f' [{index}] trial {trial + 1} key={subkey}, '
-                   f'reward={total_reward}', flush=True)
-         average_reward += total_reward / kwargs['eval_trials']
-     if kwargs['verbose']:
-         print(f'[{index}] average reward={average_reward}', flush=True)
-     return average_reward
-
-
- def power_two_int(x):
-     return 2 ** int(x)
-
-
- class JaxParameterTuningDRP(JaxParameterTuning):
-
-     def __init__(self, *args,
-                  hyperparams_dict: Dict[str, Tuple[float, float, Callable]]={
-                      'lr': (-7., 2., power_ten),
-                      'w': (0., 5., power_ten),
-                      'layers': (1., 3., int),
-                      'neurons': (2., 9., power_two_int)
-                  },
-                  **kwargs) -> None:
-         '''Creates a new tuning class for deep reactive policies.
-
-         :param *args: arguments to pass to parent class
-         :param hyperparams_dict: same as parent class, but here must contain
-         learning rate (lr), model weight (w), number of hidden layers (layers)
-         and number of neurons per hidden layer (neurons)
-         :param **kwargs: keyword arguments to pass to parent class
-         '''
-
-         super(JaxParameterTuningDRP, self).__init__(
-             *args, hyperparams_dict=hyperparams_dict, **kwargs)
-
-     def _pickleable_objective_with_kwargs(self):
-         objective_fn = objective_drp
-
-         # duplicate planner and plan keyword arguments must be removed
-         plan_kwargs = self.plan_kwargs.copy()
-         plan_kwargs.pop('topology', None)
-
-         planner_kwargs = self.planner_kwargs.copy()
-         planner_kwargs.pop('rddl', None)
-         planner_kwargs.pop('plan', None)
-         planner_kwargs.pop('optimizer_kwargs', None)
-
-         kwargs = {
-             'rddl': self.env.model,
-             'domain': self.env.domain_text,
-             'instance': self.env.instance_text,
-             'hyperparams_dict': self.hyperparams_dict,
-             'timeout_training': self.timeout_training,
-             'train_epochs': self.train_epochs,
-             'planner_kwargs': planner_kwargs,
-             'plan_kwargs': plan_kwargs,
-             'verbose': self.verbose,
-             'eval_trials': self.eval_trials
-         }
-         return objective_fn, kwargs
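
In 1.0 the per-algorithm tuner subclasses (SLP, replanning, DRP) and the planner/plan keyword arguments are gone: a single JaxParameterTuning substitutes tuned values for placeholder tags inside a planner config template, as described in the constructor docstring above. The snippet below is a minimal usage sketch based only on the signatures visible in this diff; the placeholder tags ('LR_TUNE', 'W_TUNE'), the assumption that tuning_slp.cfg contains them, and the RDDL file paths are illustrative, not values taken from the package.

    # hedged sketch of the 1.0 tuning API; tag names, bounds, and paths are assumptions
    from pyRDDLGym.core.env import RDDLEnv
    from pyRDDLGym_jax.core.tuning import Hyperparameter, JaxParameterTuning

    # planner config whose placeholder strings are replaced by tuned values
    # (tuning_slp.cfg is assumed here to contain the tags LR_TUNE and W_TUNE)
    with open('tuning_slp.cfg') as cfg:
        config_template = cfg.read()

    hyperparams = [
        # the search runs in log10 space; search_to_config_map converts a suggested
        # point back into the value written into the config template
        Hyperparameter('LR_TUNE', -5.0, 1.0, lambda x: 10.0 ** x),
        Hyperparameter('W_TUNE', 0.0, 5.0, lambda x: 10.0 ** x),
    ]

    env = RDDLEnv(domain='domain.rddl', instance='instance.rddl', vectorized=True)
    tuning = JaxParameterTuning(
        env=env,
        config_template=config_template,
        hyperparams=hyperparams,   # tags not present in the template are ignored
        online=False,              # False: offline training/evaluation; True: replanning
        eval_trials=3,
        num_workers=2,
        gp_iters=10)
    best_params = tuning.tune(key=42, log_file='tuning_log.csv')
    print(tuning.best_config)      # config text with the best values substituted in

Note that tune() now takes an integer seed (converted internally with jax.random.PRNGKey) and appends one CSV row per worker per iteration to log_file, using the COLUMNS header defined at the top of the module.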