pyRDDLGym-jax 2.4 → 2.6 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -39,6 +39,7 @@ import configparser
 from enum import Enum
 from functools import partial
 import os
+import pickle
 import sys
 import time
 import traceback
@@ -206,6 +207,13 @@ def _load_config(config, args):
         pgpe_kwargs['optimizer'] = pgpe_optimizer
         planner_args['pgpe'] = getattr(sys.modules[__name__], pgpe_method)(**pgpe_kwargs)

+    # preprocessor settings
+    preproc_method = planner_args.get('preprocessor', None)
+    preproc_kwargs = planner_args.pop('preprocessor_kwargs', {})
+    if preproc_method is not None:
+        planner_args['preprocessor'] = getattr(
+            sys.modules[__name__], preproc_method)(**preproc_kwargs)
+
     # optimize call RNG key
     planner_key = train_args.get('key', None)
     if planner_key is not None:
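The two new keys are read straight out of the planner arguments during config parsing, so the preprocessor can be enabled from a config string. A minimal sketch follows; only 'preprocessor', 'preprocessor_kwargs' and the StaticNormalizer class name come from this diff, while the import path, section names and remaining keys are assumptions about the usual pyRDDLGym-jax config layout:

# Hedged sketch: enabling the new state preprocessor via the config loader.
# Assumed: module path, section names, and all keys other than
# 'preprocessor' and 'preprocessor_kwargs'.
from pyRDDLGym_jax.core.planner import load_config_from_string

CONFIG = """
[Model]

[Optimizer]
method='JaxDeepReactivePolicy'
method_kwargs={'topology': [64, 64]}
preprocessor='StaticNormalizer'
preprocessor_kwargs={}

[Training]
epochs=1000
"""
config_args = load_config_from_string(CONFIG)   # tuple of kwargs dicts (Kwargs, ...)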
@@ -229,13 +237,19 @@ def _load_config(config, args):


 def load_config(path: str) -> Tuple[Kwargs, ...]:
-    '''Loads a config file at the specified file path.'''
+    '''Loads a config file at the specified file path.
+
+    :param path: the path of the config file to load and parse
+    '''
     config, args = _parse_config_file(path)
     return _load_config(config, args)


 def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
-    '''Loads config file contents specified explicitly as a string value.'''
+    '''Loads config file contents specified explicitly as a string value.
+
+    :param value: the string in json format containing the config contents to parse
+    '''
     config, args = _parse_config_string(value)
     return _load_config(config, args)

@@ -258,6 +272,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
     def __init__(self, *args,
                  logic: Logic=FuzzyLogic(),
                  cpfs_without_grad: Optional[Set[str]]=None,
+                 print_warnings: bool=True,
                  **kwargs) -> None:
         '''Creates a new RDDL to Jax compiler, where operations that are not
         differentiable are converted to approximate forms that have defined gradients.
@@ -268,6 +283,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
             to customize these operations
         :param cpfs_without_grad: which CPFs do not have gradients (use straight
             through gradient trick)
+        :param print_warnings: whether to print warnings
         :param *kwargs: keyword arguments to pass to base compiler
         '''
         super(JaxRDDLCompilerWithGrad, self).__init__(*args, **kwargs)
@@ -277,6 +293,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         if cpfs_without_grad is None:
             cpfs_without_grad = set()
         self.cpfs_without_grad = cpfs_without_grad
+        self.print_warnings = print_warnings

         # actions and CPFs must be continuous
         pvars_cast = set()
@@ -284,7 +301,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
             self.init_values[var] = np.asarray(values, dtype=self.REAL)
             if not np.issubdtype(np.result_type(values), np.floating):
                 pvars_cast.add(var)
-        if pvars_cast:
+        if self.print_warnings and pvars_cast:
             message = termcolor.colored(
                 f'[INFO] JAX gradient compiler will cast p-vars {pvars_cast} to float.',
                 'green')
@@ -314,12 +331,12 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
             if cpf in self.cpfs_without_grad:
                 jax_cpfs[cpf] = self._jax_stop_grad(jax_cpfs[cpf])

-        if cpfs_cast:
+        if self.print_warnings and cpfs_cast:
             message = termcolor.colored(
                 f'[INFO] JAX gradient compiler will cast CPFs {cpfs_cast} to float.',
                 'green')
             print(message)
-        if self.cpfs_without_grad:
+        if self.print_warnings and self.cpfs_without_grad:
             message = termcolor.colored(
                 f'[INFO] Gradients will not flow through CPFs {self.cpfs_without_grad}.',
                 'green')
@@ -333,6 +350,100 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         return arg


+# ***********************************************************************
+# ALL VERSIONS OF STATE PREPROCESSING FOR DRP
+#
+# - static normalization
+#
+# ***********************************************************************
+
+
+class Preprocessor(metaclass=ABCMeta):
+    '''Base class for all state preprocessors.'''
+
+    HYPERPARAMS_KEY = 'preprocessor__'
+
+    def __init__(self) -> None:
+        self._initializer = None
+        self._update = None
+        self._transform = None
+
+    @property
+    def initialize(self):
+        return self._initializer
+
+    @property
+    def update(self):
+        return self._update
+
+    @property
+    def transform(self):
+        return self._transform
+
+    @abstractmethod
+    def compile(self, compiled: JaxRDDLCompilerWithGrad) -> None:
+        pass
+
+
+class StaticNormalizer(Preprocessor):
+    '''Normalize values by box constraints on fluents computed from the RDDL domain.'''
+
+    def __init__(self, fluent_bounds: Dict[str, Tuple[np.ndarray, np.ndarray]]={}) -> None:
+        '''Create a new instance of the static normalizer.
+
+        :param fluent_bounds: optional bounds on fluents to overwrite default values.
+        '''
+        self.fluent_bounds = fluent_bounds
+
+    def compile(self, compiled: JaxRDDLCompilerWithGrad) -> None:
+
+        # adjust for partial observability
+        rddl = compiled.rddl
+        if rddl.observ_fluents:
+            observed_vars = rddl.observ_fluents
+        else:
+            observed_vars = rddl.state_fluents
+
+        # ignore boolean fluents and infinite bounds
+        bounded_vars = {}
+        for var in observed_vars:
+            if rddl.variable_ranges[var] != 'bool':
+                lower, upper = compiled.constraints.bounds[var]
+                if np.all(np.isfinite(lower) & np.isfinite(upper) & (lower < upper)):
+                    bounded_vars[var] = (lower, upper)
+                user_bounds = self.fluent_bounds.get(var, None)
+                if user_bounds is not None:
+                    bounded_vars[var] = tuple(user_bounds)
+
+        # initialize to ranges computed by the constraint parser
+        def _jax_wrapped_normalizer_init():
+            return bounded_vars
+        self._initializer = jax.jit(_jax_wrapped_normalizer_init)
+
+        # static bounds
+        def _jax_wrapped_normalizer_update(subs, stats):
+            stats = {var: (jnp.asarray(lower, dtype=compiled.REAL),
+                           jnp.asarray(upper, dtype=compiled.REAL))
+                     for (var, (lower, upper)) in bounded_vars.items()}
+            return stats
+        self._update = jax.jit(_jax_wrapped_normalizer_update)
+
+        # apply min max scaling
+        def _jax_wrapped_normalizer_transform(subs, stats):
+            new_subs = {}
+            for (var, values) in subs.items():
+                if var in stats:
+                    lower, upper = stats[var]
+                    new_dims = jnp.ndim(values) - jnp.ndim(lower)
+                    lower = lower[(jnp.newaxis,) * new_dims + (...,)]
+                    upper = upper[(jnp.newaxis,) * new_dims + (...,)]
+                    new_subs[var] = (values - lower) / (upper - lower)
+                else:
+                    new_subs[var] = values
+            return new_subs
+        self._transform = jax.jit(_jax_wrapped_normalizer_transform)
+
+
 # ***********************************************************************
 # ALL VERSIONS OF JAX PLANS
 #
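As a standalone illustration of what _jax_wrapped_normalizer_transform does, the following sketch applies the same min-max scaling, with per-fluent bounds broadcast over any extra leading batch/time dimensions (the shapes and numbers are illustrative, not taken from the package):

import jax.numpy as jnp

lower = jnp.array([0.0, -5.0])                     # per-element lower bounds
upper = jnp.array([10.0, 5.0])                     # per-element upper bounds
values = jnp.array([[[2.0, 0.0], [10.0, -5.0]]])   # (batch, time, fluent) = (1, 2, 2)
new_dims = jnp.ndim(values) - jnp.ndim(lower)
lo = lower[(jnp.newaxis,) * new_dims + (...,)]     # broadcast bounds to (1, 1, 2)
hi = upper[(jnp.newaxis,) * new_dims + (...,)]
print((values - lo) / (hi - lo))                   # [[[0.2 0.5] [1.  0. ]]]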
@@ -358,7 +469,8 @@ class JaxPlan(metaclass=ABCMeta):
     @abstractmethod
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
                 _bounds: Bounds,
-                horizon: int) -> None:
+                horizon: int,
+                preprocessor: Optional[Preprocessor]=None) -> None:
         pass

     @abstractmethod
@@ -436,10 +548,11 @@ class JaxPlan(metaclass=ABCMeta):
                                 ~lower_finite & upper_finite,
                                 ~lower_finite & ~upper_finite]
             bounds[name] = (lower, upper)
-            message = termcolor.colored(
-                f'[INFO] Bounds of action-fluent <{name}> set to {bounds[name]}.',
-                'green')
-            print(message)
+            if compiled.print_warnings:
+                message = termcolor.colored(
+                    f'[INFO] Bounds of action-fluent <{name}> set to {bounds[name]}.',
+                    'green')
+                print(message)
         return shapes, bounds, bounds_safe, cond_lists

     def _count_bool_actions(self, rddl: RDDLLiftedModel):
@@ -508,7 +621,8 @@ class JaxStraightLinePlan(JaxPlan):

     def compile(self, compiled: JaxRDDLCompilerWithGrad,
                 _bounds: Bounds,
-                horizon: int) -> None:
+                horizon: int,
+                preprocessor: Optional[Preprocessor]=None) -> None:
         rddl = compiled.rddl

         # calculate the correct action box bounds
@@ -519,7 +633,7 @@ class JaxStraightLinePlan(JaxPlan):
         # action concurrency check
         bool_action_count, allowed_actions = self._count_bool_actions(rddl)
         use_constraint_satisfaction = allowed_actions < bool_action_count
-        if use_constraint_satisfaction:
+        if compiled.print_warnings and use_constraint_satisfaction:
             message = termcolor.colored(
                 f'[INFO] SLP will use projected gradient to satisfy '
                 f'max_nondef_actions since total boolean actions '
@@ -596,7 +710,7 @@ class JaxStraightLinePlan(JaxPlan):
             return new_params, True

         # convert softmax action back to action dict
-        action_sizes = {var: np.prod(shape[1:], dtype=int)
+        action_sizes = {var: np.prod(shape[1:], dtype=np.int64)
                         for (var, shape) in shapes.items()
                         if ranges[var] == 'bool'}
@@ -605,7 +719,7 @@ class JaxStraightLinePlan(JaxPlan):
             start = 0
             for (name, size) in action_sizes.items():
                 action = output[..., start:start + size]
-                action = jnp.reshape(action, newshape=shapes[name][1:])
+                action = jnp.reshape(action, shapes[name][1:])
                 if noop[name]:
                     action = 1.0 - action
                 actions[name] = action
@@ -680,7 +794,7 @@ class JaxStraightLinePlan(JaxPlan):
             scores = []
             for (var, param) in params.items():
                 if ranges[var] == 'bool':
-                    param_flat = jnp.ravel(param)
+                    param_flat = jnp.ravel(param, order='C')
                     if noop[var]:
                         if wrap_sigmoid:
                             param_flat = -param_flat
@@ -838,7 +952,7 @@ class JaxStraightLinePlan(JaxPlan):

     def guess_next_epoch(self, params: Pytree) -> Pytree:
         next_fn = JaxStraightLinePlan._guess_next_epoch
-        return jax.tree_map(next_fn, params)
+        return jax.tree_util.tree_map(next_fn, params)


 class JaxDeepReactivePolicy(JaxPlan):
@@ -897,7 +1011,8 @@ class JaxDeepReactivePolicy(JaxPlan):

     def compile(self, compiled: JaxRDDLCompilerWithGrad,
                 _bounds: Bounds,
-                horizon: int) -> None:
+                horizon: int,
+                preprocessor: Optional[Preprocessor]=None) -> None:
         rddl = compiled.rddl

         # calculate the correct action box bounds
@@ -928,7 +1043,7 @@ class JaxDeepReactivePolicy(JaxPlan):
         wrap_non_bool = self._wrap_non_bool
         init = self._initializer
         layers = list(enumerate(zip(self._topology, self._activations)))
-        layer_sizes = {var: np.prod(shape, dtype=int)
+        layer_sizes = {var: np.prod(shape, dtype=np.int64)
                        for (var, shape) in shapes.items()}
         layer_names = {var: f'output_{var}'.replace('-', '_') for var in shapes}

@@ -946,21 +1061,28 @@ class JaxDeepReactivePolicy(JaxPlan):
             if ranges[var] != 'bool':
                 value_size = np.size(values)
                 if normalize_per_layer and value_size == 1:
-                    message = termcolor.colored(
-                        f'[WARN] Cannot apply layer norm to state-fluent <{var}> '
-                        f'of size 1: setting normalize_per_layer = False.', 'yellow')
-                    print(message)
+                    if compiled.print_warnings:
+                        message = termcolor.colored(
+                            f'[WARN] Cannot apply layer norm to state-fluent <{var}> '
+                            f'of size 1: setting normalize_per_layer = False.', 'yellow')
+                        print(message)
                     normalize_per_layer = False
                 non_bool_dims += value_size
         if not normalize_per_layer and non_bool_dims == 1:
-            message = termcolor.colored(
-                '[WARN] Cannot apply layer norm to state-fluents of total size 1: '
-                'setting normalize = False.', 'yellow')
-            print(message)
+            if compiled.print_warnings:
+                message = termcolor.colored(
+                    '[WARN] Cannot apply layer norm to state-fluents of total size 1: '
+                    'setting normalize = False.', 'yellow')
+                print(message)
             normalize = False

         # convert subs dictionary into a state vector to feed to the MLP
-        def _jax_wrapped_policy_input(subs):
+        def _jax_wrapped_policy_input(subs, hyperparams):
+
+            # optional state preprocessing
+            if preprocessor is not None:
+                stats = hyperparams[preprocessor.HYPERPARAMS_KEY]
+                subs = preprocessor.transform(subs, stats)

             # concatenate all state variables into a single vector
             # optionally apply layer norm to each input tensor
@@ -968,7 +1090,7 @@ class JaxDeepReactivePolicy(JaxPlan):
             non_bool_dims = 0
             for (var, value) in subs.items():
                 if var in observed_vars:
-                    state = jnp.ravel(value)
+                    state = jnp.ravel(value, order='C')
                     if ranges[var] == 'bool':
                         states_bool.append(state)
                     else:
@@ -997,8 +1119,8 @@ class JaxDeepReactivePolicy(JaxPlan):
             return state

         # predict actions from the policy network for current state
-        def _jax_wrapped_policy_network_predict(subs):
-            state = _jax_wrapped_policy_input(subs)
+        def _jax_wrapped_policy_network_predict(subs, hyperparams):
+            state = _jax_wrapped_policy_input(subs, hyperparams)

             # feed state vector through hidden layers
             hidden = state
@@ -1054,7 +1176,7 @@ class JaxDeepReactivePolicy(JaxPlan):
             for (name, size) in layer_sizes.items():
                 if ranges[name] == 'bool':
                     action = output[..., start:start + size]
-                    action = jnp.reshape(action, newshape=shapes[name])
+                    action = jnp.reshape(action, shapes[name])
                     if noop[name]:
                         action = 1.0 - action
                     actions[name] = action
@@ -1063,7 +1185,7 @@ class JaxDeepReactivePolicy(JaxPlan):

         # train action prediction
         def _jax_wrapped_drp_predict_train(key, params, hyperparams, step, subs):
-            actions = predict_fn.apply(params, subs)
+            actions = predict_fn.apply(params, subs, hyperparams)
             if not wrap_non_bool:
                 for (var, action) in actions.items():
                     if var != bool_key and ranges[var] != 'bool':
@@ -1113,7 +1235,7 @@ class JaxDeepReactivePolicy(JaxPlan):
             subs = {var: value[0, ...]
                     for (var, value) in subs.items()
                     if var in observed_vars}
-            params = predict_fn.init(key, subs)
+            params = predict_fn.init(key, subs, hyperparams)
             return params

         self.initializer = _jax_wrapped_drp_init
@@ -1226,6 +1348,7 @@ class PGPE(metaclass=ABCMeta):

     @abstractmethod
     def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type,
+                print_warnings: bool,
                 parallel_updates: Optional[int]=None) -> None:
         pass

@@ -1322,6 +1445,7 @@ class GaussianPGPE(PGPE):
         )

     def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type,
+                print_warnings: bool,
                 parallel_updates: Optional[int]=None) -> None:
         sigma0 = self.init_sigma
         sigma_lo, sigma_hi = self.sigma_range
@@ -1347,7 +1471,7 @@ class GaussianPGPE(PGPE):

         def _jax_wrapped_pgpe_init(key, policy_params):
             mu = policy_params
-            sigma = jax.tree_map(partial(jnp.full_like, fill_value=sigma0), mu)
+            sigma = jax.tree_util.tree_map(partial(jnp.full_like, fill_value=sigma0), mu)
             pgpe_params = (mu, sigma)
             pgpe_opt_state = (mu_optimizer.init(mu), sigma_optimizer.init(sigma))
             r_max = -jnp.inf
@@ -1395,13 +1519,14 @@ class GaussianPGPE(PGPE):
             treedef = jax.tree_util.tree_structure(sigma)
             keys = random.split(key, num=treedef.num_leaves)
             keys_pytree = jax.tree_util.tree_unflatten(treedef=treedef, leaves=keys)
-            epsilon = jax.tree_map(_jax_wrapped_mu_noise, keys_pytree, sigma)
-            p1 = jax.tree_map(jnp.add, mu, epsilon)
-            p2 = jax.tree_map(jnp.subtract, mu, epsilon)
+            epsilon = jax.tree_util.tree_map(_jax_wrapped_mu_noise, keys_pytree, sigma)
+            p1 = jax.tree_util.tree_map(jnp.add, mu, epsilon)
+            p2 = jax.tree_util.tree_map(jnp.subtract, mu, epsilon)
             if super_symmetric:
-                epsilon_star = jax.tree_map(_jax_wrapped_epsilon_star, sigma, epsilon)
-                p3 = jax.tree_map(jnp.add, mu, epsilon_star)
-                p4 = jax.tree_map(jnp.subtract, mu, epsilon_star)
+                epsilon_star = jax.tree_util.tree_map(
+                    _jax_wrapped_epsilon_star, sigma, epsilon)
+                p3 = jax.tree_util.tree_map(jnp.add, mu, epsilon_star)
+                p4 = jax.tree_util.tree_map(jnp.subtract, mu, epsilon_star)
             else:
                 epsilon_star, p3, p4 = epsilon, p1, p2
             return p1, p2, p3, p4, epsilon, epsilon_star
@@ -1469,11 +1594,11 @@ class GaussianPGPE(PGPE):
                 r_max = jnp.maximum(r_max, r4)
             else:
                 r3, r4 = r1, r2
-            grad_mu = jax.tree_map(
+            grad_mu = jax.tree_util.tree_map(
                 partial(_jax_wrapped_mu_grad, r1=r1, r2=r2, r3=r3, r4=r4, m=r_max),
                 epsilon, epsilon_star
             )
-            grad_sigma = jax.tree_map(
+            grad_sigma = jax.tree_util.tree_map(
                 partial(_jax_wrapped_sigma_grad,
                         r1=r1, r2=r2, r3=r3, r4=r4, m=r_max, ent=ent),
                 epsilon, epsilon_star, sigma
@@ -1492,7 +1617,7 @@ class GaussianPGPE(PGPE):
                 _jax_wrapped_pgpe_grad,
                 in_axes=(0, None, None, None, None, None, None, None)
             )(keys, mu, sigma, r_max, ent, policy_hyperparams, subs, model_params)
-            mu_grad, sigma_grad = jax.tree_map(
+            mu_grad, sigma_grad = jax.tree_util.tree_map(
                 partial(jnp.mean, axis=0), (mu_grads, sigma_grads))
             new_r_max = jnp.max(r_maxs)
             return mu_grad, sigma_grad, new_r_max
@@ -1516,7 +1641,7 @@ class GaussianPGPE(PGPE):
                 sigma_grad, sigma_state, params=sigma)
             new_mu = optax.apply_updates(mu, mu_updates)
             new_sigma = optax.apply_updates(sigma, sigma_updates)
-            new_sigma = jax.tree_map(
+            new_sigma = jax.tree_util.tree_map(
                 partial(jnp.clip, min=sigma_lo, max=sigma_hi), new_sigma)
             return new_mu, new_sigma, new_mu_state, new_sigma_state

@@ -1537,7 +1662,7 @@ class GaussianPGPE(PGPE):
             if max_kl is not None:
                 old_mu_lr = new_mu_state.hyperparams['learning_rate']
                 old_sigma_lr = new_sigma_state.hyperparams['learning_rate']
-                kl_terms = jax.tree_map(
+                kl_terms = jax.tree_util.tree_map(
                     _jax_wrapped_pgpe_kl_term, new_mu, new_sigma, mu, sigma)
                 total_kl = jax.tree_util.tree_reduce(jnp.add, kl_terms)
                 kl_reduction = jnp.minimum(1.0, jnp.sqrt(max_kl / total_kl))
@@ -1618,12 +1743,21 @@ def mean_semivariance_utility(returns: jnp.ndarray, beta: float) -> float:
     return mu - 0.5 * beta * msv


+@jax.jit
+def sharpe_utility(returns: jnp.ndarray, risk_free: float) -> float:
+    return (jnp.mean(returns) - risk_free) / (jnp.std(returns) + 1e-10)
+
+
+@jax.jit
+def var_utility(returns: jnp.ndarray, alpha: float) -> float:
+    return jnp.percentile(returns, q=100 * alpha)
+
+
 @jax.jit
 def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
     var = jnp.percentile(returns, q=100 * alpha)
     mask = returns <= var
-    weights = mask / jnp.maximum(1, jnp.sum(mask))
-    return jnp.sum(returns * weights)
+    return jnp.sum(returns * mask) / jnp.maximum(1, jnp.sum(mask))


 # set of all currently valid built-in utility functions
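A quick numeric check of the new risk-aware utilities on a toy batch of returns, mirroring the formulas above (the numbers are illustrative only):

import jax.numpy as jnp

returns = jnp.array([-2.0, 0.0, 1.0, 3.0, 8.0])
alpha, risk_free = 0.2, 0.0

sharpe = (jnp.mean(returns) - risk_free) / (jnp.std(returns) + 1e-10)   # ~0.59
var = jnp.percentile(returns, q=100 * alpha)                            # -0.4 (20% quantile)
mask = returns <= var
cvar = jnp.sum(returns * mask) / jnp.maximum(1, jnp.sum(mask))          # -2.0 (mean of the tail)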
@@ -1633,8 +1767,10 @@ UTILITY_LOOKUP = {
     'mean_std': mean_deviation_utility,
     'mean_semivar': mean_semivariance_utility,
     'mean_semidev': mean_semideviation_utility,
+    'sharpe': sharpe_utility,
     'entropic': entropic_utility,
     'exponential': entropic_utility,
+    'var': var_utility,
     'cvar': cvar_utility
 }

@@ -1672,7 +1808,9 @@ class JaxBackpropPlanner:
                  compile_non_fluent_exact: bool=True,
                  logger: Optional[Logger]=None,
                  dashboard_viz: Optional[Any]=None,
-                 parallel_updates: Optional[int]=None) -> None:
+                 print_warnings: bool=True,
+                 parallel_updates: Optional[int]=None,
+                 preprocessor: Optional[Preprocessor]=None) -> None:
         '''Creates a new gradient-based algorithm for optimizing action sequences
         (plan) in the given RDDL. Some operations will be converted to their
         differentiable counterparts; the specific operations can be customized
@@ -1712,7 +1850,9 @@ class JaxBackpropPlanner:
         :param logger: to log information about compilation to file
         :param dashboard_viz: optional visualizer object from the environment
             to pass to the dashboard to visualize the policy
+        :param print_warnings: whether to print warnings
        :param parallel_updates: how many optimizers to run independently in parallel
+        :param preprocessor: optional preprocessor for state inputs to plan
         '''
         self.rddl = rddl
         self.plan = plan
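A hedged construction sketch for the new planner arguments; only the rddl, plan, preprocessor and print_warnings names are confirmed by this diff, while the environment setup, module path and the DRP topology keyword are assumptions:

# Hedged sketch: wiring the preprocessor and warning switch into the planner.
import pyRDDLGym
from pyRDDLGym_jax.core.planner import (
    JaxBackpropPlanner, JaxDeepReactivePolicy, StaticNormalizer)

env = pyRDDLGym.make('HVAC', '0', vectorized=True)      # assumed example domain
planner = JaxBackpropPlanner(
    rddl=env.model,
    plan=JaxDeepReactivePolicy(topology=[64, 64]),       # 'topology' kwarg assumed
    preprocessor=StaticNormalizer(),                     # scale policy inputs by fluent bounds
    print_warnings=False)                                # silence [INFO]/[WARN] console output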
@@ -1737,6 +1877,8 @@ class JaxBackpropPlanner:
         self.noise_kwargs = noise_kwargs
         self.pgpe = pgpe
         self.use_pgpe = pgpe is not None
+        self.print_warnings = print_warnings
+        self.preprocessor = preprocessor

         # set optimizer
         try:
@@ -1789,7 +1931,11 @@ class JaxBackpropPlanner:
         self._jax_compile_rddl()
         self._jax_compile_optimizer()

-    def summarize_system(self) -> str:
+    @staticmethod
+    def summarize_system() -> str:
+        '''Returns a string containing information about the system, Python version
+        and jax-related packages that are relevant to the current planner.
+        '''
         try:
             jaxlib_version = jax._src.lib.version_str
         except Exception as _:
@@ -1818,6 +1964,9 @@ r"""
                 f'devices: {devices_short}\n')

     def summarize_relaxations(self) -> str:
+        '''Returns a summary table containing all non-differentiable operators
+        and their relaxations.
+        '''
         result = ''
         if self.compiled.model_params:
             result += ('Some RDDL operations are non-differentiable '
@@ -1834,6 +1983,9 @@ r"""
         return result

     def summarize_hyperparameters(self) -> str:
+        '''Returns a string summarizing the hyper-parameters of the current planner
+        instance.
+        '''
         result = (f'objective hyper-parameters:\n'
                   f' utility_fn ={self.utility.__name__}\n'
                   f' utility args ={self.utility_kwargs}\n'
@@ -1852,7 +2004,8 @@ r"""
                   f' noise_kwargs ={self.noise_kwargs}\n'
                   f' batch_size_train ={self.batch_size_train}\n'
                   f' batch_size_test ={self.batch_size_test}\n'
-                  f' parallel_updates ={self.parallel_updates}\n')
+                  f' parallel_updates ={self.parallel_updates}\n'
+                  f' preprocessor ={self.preprocessor}\n')
         result += str(self.plan)
         if self.use_pgpe:
             result += str(self.pgpe)
@@ -1873,7 +2026,8 @@ r"""
             logger=self.logger,
             use64bit=self.use64bit,
             cpfs_without_grad=self.cpfs_without_grad,
-            compile_non_fluent_exact=self.compile_non_fluent_exact
+            compile_non_fluent_exact=self.compile_non_fluent_exact,
+            print_warnings=self.print_warnings
         )
         self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')

@@ -1887,10 +2041,15 @@ r"""

     def _jax_compile_optimizer(self):

+        # preprocessor
+        if self.preprocessor is not None:
+            self.preprocessor.compile(self.compiled)
+
         # policy
         self.plan.compile(self.compiled,
                           _bounds=self._action_bounds,
-                          horizon=self.horizon)
+                          horizon=self.horizon,
+                          preprocessor=self.preprocessor)
         self.train_policy = jax.jit(self.plan.train_policy)
         self.test_policy = jax.jit(self.plan.test_policy)

@@ -1898,14 +2057,16 @@ r"""
         train_rollouts = self.compiled.compile_rollouts(
             policy=self.plan.train_policy,
             n_steps=self.horizon,
-            n_batch=self.batch_size_train
+            n_batch=self.batch_size_train,
+            cache_path_info=self.preprocessor is not None
         )
         self.train_rollouts = train_rollouts

         test_rollouts = self.test_compiled.compile_rollouts(
             policy=self.plan.test_policy,
             n_steps=self.horizon,
-            n_batch=self.batch_size_test
+            n_batch=self.batch_size_test,
+            cache_path_info=False
         )
         self.test_rollouts = jax.jit(test_rollouts)

@@ -1922,7 +2083,8 @@ r"""

         # optimization
         self.update = self._jax_update(train_loss)
-        self.pytree_at = jax.jit(lambda tree, i: jax.tree_map(lambda x: x[i], tree))
+        self.pytree_at = jax.jit(
+            lambda tree, i: jax.tree_util.tree_map(lambda x: x[i], tree))

         # pgpe option
         if self.use_pgpe:
@@ -1930,6 +2092,7 @@ r"""
                 loss_fn=test_loss,
                 projection=self.plan.projection,
                 real_dtype=self.test_compiled.REAL,
+                print_warnings=self.print_warnings,
                 parallel_updates=self.parallel_updates
             )
             self.merge_pgpe = self._jax_merge_pgpe_jaxplan()
@@ -2010,7 +2173,7 @@ r"""
         # check if the gradients are all zeros
         def _jax_wrapped_zero_gradients(grad):
             leaves, _ = jax.tree_util.tree_flatten(
-                jax.tree_map(partial(jnp.allclose, b=0), grad))
+                jax.tree_util.tree_map(partial(jnp.allclose, b=0), grad))
             return jnp.all(jnp.asarray(leaves))

         # calculate the plan gradient w.r.t. return loss and update optimizer
@@ -2069,7 +2232,7 @@ r"""
             def select_fn(leaf1, leaf2):
                 expanded_mask = pgpe_mask[(...,) + (jnp.newaxis,) * (jnp.ndim(leaf1) - 1)]
                 return jnp.where(expanded_mask, leaf1, leaf2)
-            policy_params = jax.tree_map(select_fn, pgpe_param, policy_params)
+            policy_params = jax.tree_util.tree_map(select_fn, pgpe_param, policy_params)
             test_loss = jnp.where(pgpe_mask, pgpe_loss, test_loss)
             test_loss_smooth = jnp.where(pgpe_mask, pgpe_loss_smooth, test_loss_smooth)
             expanded_mask = pgpe_mask[(...,) + (jnp.newaxis,) * (jnp.ndim(converged) - 1)]
@@ -2091,7 +2254,9 @@ r"""
                     f'Variable <{name}> in subs argument is not a '
                     f'valid p-variable, must be one of '
                     f'{set(self.test_compiled.init_values.keys())}.')
-            value = np.reshape(value, newshape=np.shape(init_value))[np.newaxis, ...]
+            value = np.reshape(value, np.shape(init_value))[np.newaxis, ...]
+            if value.dtype.type is np.str_:
+                value = rddl.object_string_to_index_array(rddl.variable_ranges[name], value)
             train_value = np.repeat(value, repeats=n_train, axis=0)
             train_value = np.asarray(train_value, dtype=self.compiled.REAL)
             init_train[name] = train_value
@@ -2121,7 +2286,7 @@ r"""
                     x[np.newaxis, ...], shape=(self.parallel_updates,) + np.shape(x))
             return x

-        return jax.tree_map(make_batched, pytree)
+        return jax.tree_util.tree_map(make_batched, pytree)

     def as_optimization_problem(
             self, key: Optional[random.PRNGKey]=None,
@@ -2165,10 +2330,11 @@ r"""
         train_subs, _ = self._batched_init_subs(subs)
         model_params = self.compiled.model_params
         if policy_hyperparams is None:
-            message = termcolor.colored(
-                '[WARN] policy_hyperparams is not set, setting 1.0 for '
-                'all action-fluents which could be suboptimal.', 'yellow')
-            print(message)
+            if self.print_warnings:
+                message = termcolor.colored(
+                    '[WARN] policy_hyperparams is not set, setting 1.0 for '
+                    'all action-fluents which could be suboptimal.', 'yellow')
+                print(message)
             policy_hyperparams = {action: 1.0
                                   for action in self.rddl.action_fluents}

@@ -2318,10 +2484,11 @@ r"""

         # cannot run dashboard with parallel updates
         if dashboard is not None and self.parallel_updates is not None:
-            message = termcolor.colored(
-                '[WARN] Dashboard is unavailable if parallel_updates is not None: '
-                'setting dashboard to None.', 'yellow')
-            print(message)
+            if self.print_warnings:
+                message = termcolor.colored(
+                    '[WARN] Dashboard is unavailable if parallel_updates is not None: '
+                    'setting dashboard to None.', 'yellow')
+                print(message)
             dashboard = None

         # if PRNG key is not provided
@@ -2331,19 +2498,21 @@ r"""

         # if policy_hyperparams is not provided
         if policy_hyperparams is None:
-            message = termcolor.colored(
-                '[WARN] policy_hyperparams is not set, setting 1.0 for '
-                'all action-fluents which could be suboptimal.', 'yellow')
-            print(message)
+            if self.print_warnings:
+                message = termcolor.colored(
+                    '[WARN] policy_hyperparams is not set, setting 1.0 for '
+                    'all action-fluents which could be suboptimal.', 'yellow')
+                print(message)
             policy_hyperparams = {action: 1.0
                                   for action in self.rddl.action_fluents}

         # if policy_hyperparams is a scalar
         elif isinstance(policy_hyperparams, (int, float, np.number)):
-            message = termcolor.colored(
-                f'[INFO] policy_hyperparams is {policy_hyperparams}, '
-                f'setting this value for all action-fluents.', 'green')
-            print(message)
+            if self.print_warnings:
+                message = termcolor.colored(
+                    f'[INFO] policy_hyperparams is {policy_hyperparams}, '
+                    f'setting this value for all action-fluents.', 'green')
+                print(message)
             hyperparam_value = float(policy_hyperparams)
             policy_hyperparams = {action: hyperparam_value
                                   for action in self.rddl.action_fluents}
@@ -2352,13 +2521,20 @@ r"""
         elif isinstance(policy_hyperparams, dict):
             for action in self.rddl.action_fluents:
                 if action not in policy_hyperparams:
-                    message = termcolor.colored(
-                        f'[WARN] policy_hyperparams[{action}] is not set, '
-                        f'setting 1.0 for missing action-fluents '
-                        f'which could be suboptimal.', 'yellow')
-                    print(message)
+                    if self.print_warnings:
+                        message = termcolor.colored(
+                            f'[WARN] policy_hyperparams[{action}] is not set, '
+                            f'setting 1.0 for missing action-fluents '
+                            f'which could be suboptimal.', 'yellow')
+                        print(message)
                     policy_hyperparams[action] = 1.0
-
+
+        # initialize preprocessor
+        preproc_key = None
+        if self.preprocessor is not None:
+            preproc_key = self.preprocessor.HYPERPARAMS_KEY
+            policy_hyperparams[preproc_key] = self.preprocessor.initialize()
+
         # print summary of parameters:
         if print_summary:
             print(self.summarize_system())
@@ -2396,7 +2572,7 @@ r"""
             if var not in subs:
                 subs[var] = value
                 added_pvars_to_subs.append(var)
-        if added_pvars_to_subs:
+        if self.print_warnings and added_pvars_to_subs:
             message = termcolor.colored(
                 f'[INFO] p-variables {added_pvars_to_subs} is not in '
                 f'provided subs, using their initial values.', 'green')
@@ -2485,6 +2661,11 @@ r"""
                 subkey, policy_params, policy_hyperparams, train_subs, model_params,
                 opt_state, opt_aux)

+            # update the preprocessor
+            if self.preprocessor is not None:
+                policy_hyperparams[preproc_key] = self.preprocessor.update(
+                    train_log['fluents'], policy_hyperparams[preproc_key])
+
             # evaluate
             test_loss, (test_log, model_params_test) = self.test_loss(
                 subkey, policy_params, policy_hyperparams, test_subs, model_params_test)
@@ -2637,6 +2818,7 @@ r"""
                 'model_params': model_params,
                 'progress': progress_percent,
                 'train_log': train_log,
+                'policy_hyperparams': policy_hyperparams,
                 **test_log
             }

@@ -2648,7 +2830,7 @@ r"""
                 policy_params, opt_state, opt_aux = self.initialize(
                     subkey, policy_hyperparams, train_subs)
                 no_progress_count = 0
-                if progress_bar is not None:
+                if self.print_warnings and progress_bar is not None:
                     message = termcolor.colored(
                         f'[INFO] Optimizer restarted at iteration {it} '
                         f'due to lack of progress.', 'green')
@@ -2658,7 +2840,7 @@ r"""

             # stopping condition reached
             if stopping_rule is not None and stopping_rule.monitor(callback):
-                if progress_bar is not None:
+                if self.print_warnings and progress_bar is not None:
                     message = termcolor.colored(
                         '[SUCC] Stopping rule has been reached.', 'green')
                     progress_bar.write(message)
@@ -2699,7 +2881,8 @@ r"""

         # summarize and test for convergence
         if print_summary:
-            grad_norm = jax.tree_map(lambda x: np.linalg.norm(x).item(), best_grad)
+            grad_norm = jax.tree_util.tree_map(
+                lambda x: np.linalg.norm(x).item(), best_grad)
             diagnosis = self._perform_diagnosis(
                 last_iter_improve, -np.min(train_loss), -np.min(test_loss_smooth),
                 -best_loss, grad_norm)
@@ -2713,7 +2896,8 @@ r"""

     def _perform_diagnosis(self, last_iter_improve,
                            train_return, test_return, best_return, grad_norm):
-        max_grad_norm = max(jax.tree_util.tree_leaves(grad_norm))
+        grad_norms = jax.tree_util.tree_leaves(grad_norm)
+        max_grad_norm = max(grad_norms) if grad_norms else np.nan
         grad_is_zero = np.allclose(max_grad_norm, 0)

         # divergence if the solution is not finite
@@ -2777,6 +2961,7 @@ r"""
         :param policy_hyperparams: hyper-parameters for the policy/plan, such as
             weights for sigmoid wrapping boolean actions (optional)
         '''
+        subs = subs.copy()

         # check compatibility of the subs dictionary
         for (var, values) in subs.items():
@@ -2795,13 +2980,17 @@ r"""
                 if step == 0 and var in self.rddl.observ_fluents:
                     subs[var] = self.test_compiled.init_values[var]
                 else:
-                    raise ValueError(
-                        f'Values {values} assigned to p-variable <{var}> are '
-                        f'non-numeric of type {dtype}.')
+                    if dtype.type is np.str_:
+                        prange = self.rddl.variable_ranges[var]
+                        subs[var] = self.rddl.object_string_to_index_array(prange, subs[var])
+                    else:
+                        raise ValueError(
+                            f'Values {values} assigned to p-variable <{var}> are '
+                            f'non-numeric of type {dtype}.')

         # cast device arrays to numpy
         actions = self.test_policy(key, params, policy_hyperparams, step, subs)
-        actions = jax.tree_map(np.asarray, actions)
+        actions = jax.tree_util.tree_map(np.asarray, actions)
         return actions


@@ -2822,8 +3011,9 @@ class JaxOfflineController(BaseAgent):
     def __init__(self, planner: JaxBackpropPlanner,
                  key: Optional[random.PRNGKey]=None,
                  eval_hyperparams: Optional[Dict[str, Any]]=None,
-                 params: Optional[Pytree]=None,
+                 params: Optional[Union[str, Pytree]]=None,
                  train_on_reset: bool=False,
+                 save_path: Optional[str]=None,
                  **train_kwargs) -> None:
         '''Creates a new JAX offline control policy that is trained once, then
         deployed later.
@@ -2834,8 +3024,10 @@ class JaxOfflineController(BaseAgent):
         :param eval_hyperparams: policy hyperparameters to apply for evaluation
             or whenever sample_action is called
         :param params: use the specified policy parameters instead of calling
-            planner.optimize()
+            planner.optimize(); can be a string pointing to a valid file path where params
+            have been saved, or a pytree of parameters
         :param train_on_reset: retrain policy parameters on every episode reset
+        :param save_path: optional path to save parameters to
         :param **train_kwargs: any keyword arguments to be passed to the planner
             for optimization
         '''
@@ -2847,13 +3039,28 @@ class JaxOfflineController(BaseAgent):
         self.train_on_reset = train_on_reset
         self.train_kwargs = train_kwargs
         self.params_given = params is not None
+        self.hyperparams_given = eval_hyperparams is not None

+        # load the policy from file
+        if not self.train_on_reset and params is not None and isinstance(params, str):
+            with open(params, 'rb') as file:
+                params = pickle.load(file)
+
+        # train the policy
         self.step = 0
         self.callback = None
         if not self.train_on_reset and not self.params_given:
             callback = self.planner.optimize(key=self.key, **self.train_kwargs)
             self.callback = callback
             params = callback['best_params']
+            if not self.hyperparams_given:
+                self.eval_hyperparams = callback['policy_hyperparams']
+
+        # save the policy
+        if save_path is not None:
+            with open(save_path, 'wb') as file:
+                pickle.dump(params, file)
+
         self.params = params

     def sample_action(self, state: Dict[str, Any]) -> Dict[str, Any]:
  def sample_action(self, state: Dict[str, Any]) -> Dict[str, Any]:
@@ -2865,10 +3072,14 @@ class JaxOfflineController(BaseAgent):
2865
3072
 
2866
3073
  def reset(self) -> None:
2867
3074
  self.step = 0
3075
+
3076
+ # train the policy if required to reset at the start of every episode
2868
3077
  if self.train_on_reset and not self.params_given:
2869
3078
  callback = self.planner.optimize(key=self.key, **self.train_kwargs)
2870
3079
  self.callback = callback
2871
3080
  self.params = callback['best_params']
3081
+ if not self.hyperparams_given:
3082
+ self.eval_hyperparams = callback['policy_hyperparams']
2872
3083
 
2873
3084
 
2874
3085
  class JaxOnlineController(BaseAgent):
@@ -2901,6 +3112,7 @@ class JaxOnlineController(BaseAgent):
2901
3112
  key = random.PRNGKey(round(time.time() * 1000))
2902
3113
  self.key = key
2903
3114
  self.eval_hyperparams = eval_hyperparams
3115
+ self.hyperparams_given = eval_hyperparams is not None
2904
3116
  self.warm_start = warm_start
2905
3117
  self.train_kwargs = train_kwargs
2906
3118
  self.max_attempts = max_attempts
@@ -2915,18 +3127,24 @@ class JaxOnlineController(BaseAgent):
2915
3127
  attempts = 0
2916
3128
  while attempts < self.max_attempts and callback['iteration'] <= 1:
2917
3129
  attempts += 1
2918
- message = termcolor.colored(
2919
- f'[WARN] JIT compilation dominated the execution time: '
2920
- f'executing the optimizer again on the traced model [attempt {attempts}].',
2921
- 'yellow')
2922
- print(message)
3130
+ if self.planner.print_warnings:
3131
+ message = termcolor.colored(
3132
+ f'[WARN] JIT compilation dominated the execution time: '
3133
+ f'executing the optimizer again on the traced model '
3134
+ f'[attempt {attempts}].', 'yellow')
3135
+ print(message)
2923
3136
  callback = planner.optimize(
2924
- key=self.key, guess=self.guess, subs=state, **self.train_kwargs)
2925
-
3137
+ key=self.key, guess=self.guess, subs=state, **self.train_kwargs)
2926
3138
  self.callback = callback
2927
3139
  params = callback['best_params']
3140
+ if not self.hyperparams_given:
3141
+ self.eval_hyperparams = callback['policy_hyperparams']
3142
+
3143
+ # get the action from the parameters for the current state
2928
3144
  self.key, subkey = random.split(self.key)
2929
3145
  actions = planner.get_action(subkey, params, 0, state, self.eval_hyperparams)
3146
+
3147
+ # apply warm start for the next epoch
2930
3148
  if self.warm_start:
2931
3149
  self.guess = planner.plan.guess_next_epoch(params)
2932
3150
  return actions
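Finally, a hedged end-to-end sketch of the online (replanning) controller combining the warm start behaviour above with the risk-aware utilities added in this release; the domain, horizon and utility settings are illustrative assumptions rather than values from the package:

# Hedged sketch: online replanning with a CVaR objective and warm starting.
import pyRDDLGym
from pyRDDLGym_jax.core.planner import (
    JaxBackpropPlanner, JaxStraightLinePlan, JaxOnlineController)

env = pyRDDLGym.make('HVAC', '0', vectorized=True)        # assumed example domain
planner = JaxBackpropPlanner(
    rddl=env.model,
    plan=JaxStraightLinePlan(),
    rollout_horizon=5,                                     # replanning lookahead, assumed kwarg
    utility='cvar', utility_kwargs={'alpha': 0.1},         # lookup key from UTILITY_LOOKUP
    print_warnings=False)
agent = JaxOnlineController(planner, warm_start=True, epochs=100)
agent.evaluate(env, episodes=1)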