pyRDDLGym-jax 0.2__py3-none-any.whl → 0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,3 @@
-__version__ = '0.2'
-
 from ast import literal_eval
 from collections import deque
 import configparser
@@ -15,6 +13,7 @@ import os
 import sys
 import termcolor
 import time
+import traceback
 from tqdm import tqdm
 from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Tuple, Union
 
@@ -25,14 +24,17 @@ Pytree = Any
 
 from pyRDDLGym.core.debug.exception import raise_warning
 
+from pyRDDLGym_jax import __version__
+
 # try to import matplotlib, if failed then skip plotting
 try:
     import matplotlib
     import matplotlib.pyplot as plt
     matplotlib.use('TkAgg')
 except Exception:
-    raise_warning('matplotlib is not installed, '
-                  'plotting functionality is disabled.', 'red')
+    raise_warning('failed to import matplotlib: '
+                  'plotting functionality will be disabled.', 'red')
+    traceback.print_exc()
     plt = None
 
 from pyRDDLGym.core.compiler.model import RDDLPlanningModel, RDDLLiftedModel
@@ -113,7 +115,8 @@ def _load_config(config, args):
     # policy initialization
     plan_initializer = plan_kwargs.get('initializer', None)
     if plan_initializer is not None:
-        initializer = _getattr_any(packages=[initializers], item=plan_initializer)
+        initializer = _getattr_any(
+            packages=[initializers, hk.initializers], item=plan_initializer)
        if initializer is None:
            raise_warning(
                f'Ignoring invalid initializer <{plan_initializer}>.', 'red')
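Note on the hunk above: the initializer named in the planner configuration is now looked up in haiku's `hk.initializers` in addition to the JAX initializers module. A minimal sketch of the widened lookup, assuming a hypothetical `resolve_initializer` helper that mirrors `_getattr_any` (whose body is not part of this diff):

```python
import haiku as hk
from jax.nn import initializers

def resolve_initializer(name):
    # mirrors _getattr_any(packages=[initializers, hk.initializers], item=name):
    # return the first attribute with the given name, or None if none is found
    for package in (initializers, hk.initializers):
        if hasattr(package, name):
            return getattr(package, name)
    return None

# 'VarianceScaling' exists only in hk.initializers, so this lookup would have
# returned None in 0.2 but succeeds in 0.3
init_cls = resolve_initializer('VarianceScaling')
```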
@@ -130,7 +133,8 @@ def _load_config(config, args):
     # policy activation
     plan_activation = plan_kwargs.get('activation', None)
     if plan_activation is not None:
-        activation = _getattr_any(packages=[jax.nn, jax.numpy], item=plan_activation)
+        activation = _getattr_any(
+            packages=[jax.nn, jax.numpy], item=plan_activation)
        if activation is None:
            raise_warning(
                f'Ignoring invalid activation <{plan_activation}>.', 'red')
@@ -217,6 +221,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         :param *kwargs: keyword arguments to pass to base compiler
         '''
         super(JaxRDDLCompilerWithGrad, self).__init__(*args, **kwargs)
+
         self.logic = logic
         self.logic.set_use64bit(self.use64bit)
         if cpfs_without_grad is None:
@@ -224,9 +229,14 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         self.cpfs_without_grad = cpfs_without_grad
 
         # actions and CPFs must be continuous
-        raise_warning('Initial values of pvariables will be cast to real.')
+        pvars_cast = set()
         for (var, values) in self.init_values.items():
             self.init_values[var] = np.asarray(values, dtype=self.REAL)
+            if not np.issubdtype(np.atleast_1d(values).dtype, np.floating):
+                pvars_cast.add(var)
+        if pvars_cast:
+            raise_warning(f'JAX gradient compiler requires that initial values '
+                          f'of p-variables {pvars_cast} be cast to float.')
 
         # overwrite basic operations with fuzzy ones
         self.RELATIONAL_OPS = {
@@ -273,20 +283,29 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         return _jax_wrapped_stop_grad
 
     def _compile_cpfs(self, info):
-        raise_warning('CPFs outputs will be cast to real.')
+        cpfs_cast = set()
         jax_cpfs = {}
         for (_, cpfs) in self.levels.items():
             for cpf in cpfs:
                 _, expr = self.rddl.cpfs[cpf]
                 jax_cpfs[cpf] = self._jax(expr, info, dtype=self.REAL)
+                if self.rddl.variable_ranges[cpf] != 'real':
+                    cpfs_cast.add(cpf)
                 if cpf in self.cpfs_without_grad:
-                    raise_warning(f'CPF <{cpf}> stops gradient.')
                     jax_cpfs[cpf] = self._jax_stop_grad(jax_cpfs[cpf])
+
+        if cpfs_cast:
+            raise_warning(f'JAX gradient compiler requires that outputs of CPFs '
+                          f'{cpfs_cast} be cast to float.')
+        if self.cpfs_without_grad:
+            raise_warning(f'User requested that gradients not flow '
+                          f'through CPFs {self.cpfs_without_grad}.')
         return jax_cpfs
 
     def _jax_kron(self, expr, info):
         if self.logic.verbose:
-            raise_warning('KronDelta will be ignored.')
+            raise_warning('JAX gradient compiler ignores KronDelta '
+                          'during compilation.')
         arg, = expr.args
         arg = self._jax(arg, info)
         return arg
@@ -308,7 +327,8 @@ class JaxPlan:
         self._train_policy = None
         self._test_policy = None
         self._projection = None
-
+        self.bounds = None
+
     def summarize_hyperparameters(self) -> None:
         pass
 
@@ -363,7 +383,7 @@ class JaxPlan:
             # check invalid type
             if prange not in compiled.JAX_TYPES:
                 raise RDDLTypeError(
-                    f'Invalid range <{prange}. of action-fluent <{name}>, '
+                    f'Invalid range <{prange}> of action-fluent <{name}>, '
                     f'must be one of {set(compiled.JAX_TYPES.keys())}.')
 
             # clip boolean to (0, 1), otherwise use the RDDL action bounds
@@ -385,7 +405,7 @@ class JaxPlan:
                           ~lower_finite & upper_finite,
                           ~lower_finite & ~upper_finite]
             bounds[name] = (lower, upper)
-            raise_warning(f'Bounds of action fluent <{name}> set to {bounds[name]}.')
+            raise_warning(f'Bounds of action-fluent <{name}> set to {bounds[name]}.')
         return shapes, bounds, bounds_safe, cond_lists
 
     def _count_bool_actions(self, rddl: RDDLLiftedModel):
@@ -427,6 +447,7 @@ class JaxStraightLinePlan(JaxPlan):
         use_new_projection = True
         '''
         super(JaxStraightLinePlan, self).__init__()
+
         self._initializer_base = initializer
         self._initializer = initializer
         self._wrap_sigmoid = wrap_sigmoid
@@ -437,9 +458,12 @@ class JaxStraightLinePlan(JaxPlan):
         self._max_constraint_iter = max_constraint_iter
 
     def summarize_hyperparameters(self) -> None:
+        bounds = '\n '.join(
+            map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
         print(f'policy hyper-parameters:\n'
-              f' initializer ={type(self._initializer_base).__name__}\n'
+              f' initializer ={self._initializer_base}\n'
               f'constraint-sat strategy (simple):\n'
+              f' parsed_action_bounds =\n {bounds}\n'
              f' wrap_sigmoid ={self._wrap_sigmoid}\n'
              f' wrap_sigmoid_min_prob={self._min_action_prob}\n'
              f' wrap_non_bool ={self._wrap_non_bool}\n'
@@ -603,7 +627,7 @@ class JaxStraightLinePlan(JaxPlan):
         if 1 < allowed_actions < bool_action_count:
             raise RDDLNotImplementedError(
                 f'Straight-line plans with wrap_softmax currently '
-                f'do not support max-nondef-actions = {allowed_actions} > 1.')
+                f'do not support max-nondef-actions {allowed_actions} > 1.')
 
         # potentially apply projection but to non-bool actions only
         self.projection = _jax_wrapped_slp_project_to_box
@@ -734,14 +758,14 @@ class JaxStraightLinePlan(JaxPlan):
             for (var, shape) in shapes.items():
                 if ranges[var] != 'bool' or not stack_bool_params:
                     key, subkey = random.split(key)
-                    param = init(subkey, shape, dtype=compiled.REAL)
+                    param = init(key=subkey, shape=shape, dtype=compiled.REAL)
                     if ranges[var] == 'bool':
                         param += bool_threshold
                     params[var] = param
             if stack_bool_params:
                 key, subkey = random.split(key)
                 bool_shape = (horizon, bool_action_count)
-                bool_param = init(subkey, bool_shape, dtype=compiled.REAL)
+                bool_param = init(key=subkey, shape=bool_shape, dtype=compiled.REAL)
                 params[bool_key] = bool_param
             params, _ = _jax_wrapped_slp_project_to_box(params, hyperparams)
             return params
@@ -765,7 +789,8 @@ class JaxDeepReactivePolicy(JaxPlan):
     def __init__(self, topology: Optional[Sequence[int]]=None,
                  activation: Activation=jnp.tanh,
                  initializer: hk.initializers.Initializer=hk.initializers.VarianceScaling(scale=2.0),
-                 normalize: bool=True,
+                 normalize: bool=False,
+                 normalize_per_layer: bool=False,
                  normalizer_kwargs: Optional[Kwargs]=None,
                  wrap_non_bool: bool=False) -> None:
         '''Creates a new deep reactive policy in JAX.
@@ -775,12 +800,15 @@ class JaxDeepReactivePolicy(JaxPlan):
         :param activation: function to apply after each layer of the policy
         :param initializer: weight initialization
         :param normalize: whether to apply layer norm to the inputs
+        :param normalize_per_layer: whether to apply layer norm to each input
+        individually (only active if normalize is True)
         :param normalizer_kwargs: if normalize is True, apply additional arguments
         to layer norm
         :param wrap_non_bool: whether to wrap real or int action fluent parameters
         with non-linearity (e.g. sigmoid or ELU) to satisfy box constraints
         '''
         super(JaxDeepReactivePolicy, self).__init__()
+
         if topology is None:
             topology = [128, 64]
         self._topology = topology
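For reference, a hedged usage sketch of the revised constructor; the import path `pyRDDLGym_jax.core.planner` and the surrounding setup are assumptions not shown in this diff. Note that `normalize` now defaults to False, and the new `normalize_per_layer` flag only has an effect when `normalize` is True:

```python
import haiku as hk
import jax.numpy as jnp

from pyRDDLGym_jax.core.planner import JaxDeepReactivePolicy  # assumed import path

policy = JaxDeepReactivePolicy(
    topology=[128, 64],
    activation=jnp.tanh,
    initializer=hk.initializers.VarianceScaling(scale=2.0),
    normalize=True,               # default changed from True to False in 0.3
    normalize_per_layer=True,     # new in 0.3: layer norm per input fluent
    normalizer_kwargs={'create_offset': True, 'create_scale': True})
```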
@@ -788,22 +816,25 @@ class JaxDeepReactivePolicy(JaxPlan):
         self._initializer_base = initializer
         self._initializer = initializer
         self._normalize = normalize
+        self._normalize_per_layer = normalize_per_layer
         if normalizer_kwargs is None:
-            normalizer_kwargs = {
-                'create_offset': True, 'create_scale': True,
-                'name': 'input_norm'
-            }
+            normalizer_kwargs = {'create_offset': True, 'create_scale': True}
         self._normalizer_kwargs = normalizer_kwargs
         self._wrap_non_bool = wrap_non_bool
 
     def summarize_hyperparameters(self) -> None:
+        bounds = '\n '.join(
+            map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
         print(f'policy hyper-parameters:\n'
-              f' topology ={self._topology}\n'
-              f' activation_fn ={self._activations[0].__name__}\n'
-              f' initializer ={type(self._initializer_base).__name__}\n'
-              f' apply_layer_norm={self._normalize}\n'
-              f' layer_norm_args ={self._normalizer_kwargs}\n'
-              f' wrap_non_bool ={self._wrap_non_bool}')
+              f' topology ={self._topology}\n'
+              f' activation_fn ={self._activations[0].__name__}\n'
+              f' initializer ={type(self._initializer_base).__name__}\n'
+              f' apply_input_norm ={self._normalize}\n'
+              f' input_norm_layerwise={self._normalize_per_layer}\n'
+              f' input_norm_args ={self._normalizer_kwargs}\n'
+              f'constraint-sat strategy:\n'
+              f' parsed_action_bounds=\n {bounds}\n'
+              f' wrap_non_bool ={self._wrap_non_bool}')
 
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
                 _bounds: Bounds,
@@ -821,7 +852,7 @@ class JaxDeepReactivePolicy(JaxPlan):
         if 1 < allowed_actions < bool_action_count:
             raise RDDLNotImplementedError(
                 f'Deep reactive policies currently do not support '
-                f'max-nondef-actions = {allowed_actions} > 1.')
+                f'max-nondef-actions {allowed_actions} > 1.')
         use_constraint_satisfaction = allowed_actions < bool_action_count
 
         noop = {var: (values[0] if isinstance(values, list) else values)
@@ -835,6 +866,7 @@ class JaxDeepReactivePolicy(JaxPlan):
 
         ranges = rddl.variable_ranges
         normalize = self._normalize
+        normalize_per_layer = self._normalize_per_layer
         wrap_non_bool = self._wrap_non_bool
         init = self._initializer
         layers = list(enumerate(zip(self._topology, self._activations)))
@@ -842,14 +874,67 @@ class JaxDeepReactivePolicy(JaxPlan):
                   for (var, shape) in shapes.items()}
         layer_names = {var: f'output_{var}'.replace('-', '_') for var in shapes}
 
-        # predict actions from the policy network for current state
-        def _jax_wrapped_policy_network_predict(state):
+        # inputs for the policy network
+        if rddl.observ_fluents:
+            observed_vars = rddl.observ_fluents
+        else:
+            observed_vars = rddl.state_fluents
+        input_names = {var: f'{var}'.replace('-', '_') for var in observed_vars}
+
+        # catch if input norm is applied to size 1 tensor
+        if normalize:
+            non_bool_dims = 0
+            for (var, values) in observed_vars.items():
+                if ranges[var] != 'bool':
+                    value_size = np.atleast_1d(values).size
+                    if normalize_per_layer and value_size == 1:
+                        raise_warning(
+                            f'Cannot apply layer norm to state-fluent <{var}> '
+                            f'of size 1: setting normalize_per_layer = False.',
+                            'red')
+                        normalize_per_layer = False
+                    non_bool_dims += value_size
+            if not normalize_per_layer and non_bool_dims == 1:
+                raise_warning(
+                    'Cannot apply layer norm to state-fluents of total size 1: '
+                    'setting normalize = False.', 'red')
+                normalize = False
+
+        # convert subs dictionary into a state vector to feed to the MLP
+        def _jax_wrapped_policy_input(subs):
+
+            # concatenate all state variables into a single vector
+            # optionally apply layer norm to each input tensor
+            states_bool, states_non_bool = [], []
+            non_bool_dims = 0
+            for (var, value) in subs.items():
+                if var in observed_vars:
+                    state = jnp.ravel(value)
+                    if ranges[var] == 'bool':
+                        states_bool.append(state)
+                    else:
+                        if normalize and normalize_per_layer:
+                            normalizer = hk.LayerNorm(
+                                axis=-1, param_axis=-1,
+                                name=f'input_norm_{input_names[var]}',
+                                **self._normalizer_kwargs)
+                            state = normalizer(state)
+                        states_non_bool.append(state)
+                        non_bool_dims += state.size
+            state = jnp.concatenate(states_non_bool + states_bool)
 
-            # apply layer norm
-            if normalize:
+            # optionally perform layer normalization on the non-bool inputs
+            if normalize and not normalize_per_layer and non_bool_dims:
                 normalizer = hk.LayerNorm(
-                    axis=-1, param_axis=-1, **self._normalizer_kwargs)
-                state = normalizer(state)
+                    axis=-1, param_axis=-1, name='input_norm',
+                    **self._normalizer_kwargs)
+                normalized = normalizer(state[:non_bool_dims])
+                state = state.at[:non_bool_dims].set(normalized)
+            return state
+
+        # predict actions from the policy network for current state
+        def _jax_wrapped_policy_network_predict(subs):
+            state = _jax_wrapped_policy_input(subs)
 
             # feed state vector through hidden layers
             hidden = state
@@ -913,25 +998,9 @@ class JaxDeepReactivePolicy(JaxPlan):
                 start += size
             return actions
 
-        if rddl.observ_fluents:
-            observed_vars = rddl.observ_fluents
-        else:
-            observed_vars = rddl.state_fluents
-
-        # state is concatenated into single tensor
-        def _jax_wrapped_subs_to_state(subs):
-            subs = {var: value
-                    for (var, value) in subs.items()
-                    if var in observed_vars}
-            flat_subs = jax.tree_map(jnp.ravel, subs)
-            states = list(flat_subs.values())
-            state = jnp.concatenate(states)
-            return state
-
         # train action prediction
         def _jax_wrapped_drp_predict_train(key, params, hyperparams, step, subs):
-            state = _jax_wrapped_subs_to_state(subs)
-            actions = predict_fn.apply(params, state)
+            actions = predict_fn.apply(params, subs)
             if not wrap_non_bool:
                 for (var, action) in actions.items():
                     if var != bool_key and ranges[var] != 'bool':
@@ -982,8 +1051,7 @@ class JaxDeepReactivePolicy(JaxPlan):
             subs = {var: value[0, ...]
                     for (var, value) in subs.items()
                     if var in observed_vars}
-            state = _jax_wrapped_subs_to_state(subs)
-            params = predict_fn.init(key, state)
+            params = predict_fn.init(key, subs)
             return params
 
         self.initializer = _jax_wrapped_drp_init
@@ -1021,46 +1089,72 @@ class RollingMean:
 class JaxPlannerPlot:
     '''Supports plotting and visualization of a JAX policy in real time.'''
 
-    def __init__(self, rddl: RDDLPlanningModel, horizon: int) -> None:
-        self._fig, axes = plt.subplots(1 + len(rddl.action_fluents))
+    def __init__(self, rddl: RDDLPlanningModel, horizon: int,
+                 show_violin: bool=True, show_action: bool=True) -> None:
+        '''Creates a new planner visualizer.
+
+        :param rddl: the planning model to optimize
+        :param horizon: the lookahead or planning horizon
+        :param show_violin: whether to show the distribution of batch losses
+        :param show_action: whether to show heatmaps of the action fluents
+        '''
+        num_plots = 1
+        if show_violin:
+            num_plots += 1
+        if show_action:
+            num_plots += len(rddl.action_fluents)
+        self._fig, axes = plt.subplots(num_plots)
+        if num_plots == 1:
+            axes = [axes]
 
         # prepare the loss plot
         self._loss_ax = axes[0]
         self._loss_ax.autoscale(enable=True)
-        self._loss_ax.set_xlabel('decision epoch')
+        self._loss_ax.set_xlabel('training time')
         self._loss_ax.set_ylabel('loss value')
         self._loss_plot = self._loss_ax.plot(
             [], [], linestyle=':', marker='o', markersize=2)[0]
         self._loss_back = self._fig.canvas.copy_from_bbox(self._loss_ax.bbox)
 
+        # prepare the violin plot
+        if show_violin:
+            self._hist_ax = axes[1]
+        else:
+            self._hist_ax = None
+
         # prepare the action plots
-        self._action_ax = {name: axes[idx + 1]
-                           for (idx, name) in enumerate(rddl.action_fluents)}
-        self._action_plots = {}
-        for name in rddl.action_fluents:
-            ax = self._action_ax[name]
-            if rddl.variable_ranges[name] == 'bool':
-                vmin, vmax = 0.0, 1.0
-            else:
-                vmin, vmax = None, None
-            action_dim = 1
-            for dim in rddl.object_counts(rddl.variable_params[name]):
-                action_dim *= dim
-            action_plot = ax.pcolormesh(
-                np.zeros((action_dim, horizon)),
-                cmap='seismic', vmin=vmin, vmax=vmax)
-            ax.set_aspect('auto')
-            ax.set_xlabel('decision epoch')
-            ax.set_ylabel(name)
-            plt.colorbar(action_plot, ax=ax)
-            self._action_plots[name] = action_plot
-        self._action_back = {name: self._fig.canvas.copy_from_bbox(ax.bbox)
-                             for (name, ax) in self._action_ax.items()}
-
+        if show_action:
+            self._action_ax = {name: axes[idx + (2 if show_violin else 1)]
+                               for (idx, name) in enumerate(rddl.action_fluents)}
+            self._action_plots = {}
+            for name in rddl.action_fluents:
+                ax = self._action_ax[name]
+                if rddl.variable_ranges[name] == 'bool':
+                    vmin, vmax = 0.0, 1.0
+                else:
+                    vmin, vmax = None, None
+                action_dim = 1
+                for dim in rddl.object_counts(rddl.variable_params[name]):
+                    action_dim *= dim
+                action_plot = ax.pcolormesh(
+                    np.zeros((action_dim, horizon)),
+                    cmap='seismic', vmin=vmin, vmax=vmax)
+                ax.set_aspect('auto')
+                ax.set_xlabel('decision epoch')
+                ax.set_ylabel(name)
+                plt.colorbar(action_plot, ax=ax)
+                self._action_plots[name] = action_plot
+            self._action_back = {name: self._fig.canvas.copy_from_bbox(ax.bbox)
+                                 for (name, ax) in self._action_ax.items()}
+        else:
+            self._action_ax = None
+            self._action_plots = None
+            self._action_back = None
+
         plt.tight_layout()
         plt.show(block=False)
 
-    def redraw(self, xticks, losses, actions) -> None:
+    def redraw(self, xticks, losses, actions, returns) -> None:
 
         # draw the loss curve
         self._fig.canvas.restore_region(self._loss_back)
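The constructor now lets callers turn the extra panels off. A short sketch using only the arguments shown above (`rddl` and `horizon` are assumed to come from the surrounding planner setup):

```python
# loss curve only: skip the violin plot of batch losses and the action heatmaps
plot = JaxPlannerPlot(rddl, horizon, show_violin=False, show_action=False)
```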
@@ -1071,21 +1165,30 @@ class JaxPlannerPlot:
         self._loss_ax.draw_artist(self._loss_plot)
         self._fig.canvas.blit(self._loss_ax.bbox)
 
+        # draw the violin plot
+        if self._hist_ax is not None:
+            self._hist_ax.clear()
+            self._hist_ax.set_xlabel('loss value')
+            self._hist_ax.set_ylabel('density')
+            self._hist_ax.violinplot(returns, vert=False, showmeans=True)
+
         # draw the actions
-        for (name, values) in actions.items():
-            values = np.mean(values, axis=0, dtype=float)
-            values = np.reshape(values, newshape=(values.shape[0], -1)).T
-            self._fig.canvas.restore_region(self._action_back[name])
-            self._action_plots[name].set_array(values)
-            self._action_ax[name].draw_artist(self._action_plots[name])
-            self._fig.canvas.blit(self._action_ax[name].bbox)
-            self._action_plots[name].set_clim([np.min(values), np.max(values)])
+        if self._action_ax is not None:
+            for (name, values) in actions.items():
+                values = np.mean(values, axis=0, dtype=float)
+                values = np.reshape(values, newshape=(values.shape[0], -1)).T
+                self._fig.canvas.restore_region(self._action_back[name])
+                self._action_plots[name].set_array(values)
+                self._action_ax[name].draw_artist(self._action_plots[name])
+                self._fig.canvas.blit(self._action_ax[name].bbox)
+                self._action_plots[name].set_clim([np.min(values), np.max(values)])
+
         self._fig.canvas.draw()
         self._fig.canvas.flush_events()
 
     def close(self) -> None:
         plt.close(self._fig)
-        del self._loss_ax, self._action_ax, \
+        del self._loss_ax, self._hist_ax, self._action_ax, \
            self._loss_plot, self._action_plots, self._fig, \
            self._loss_back, self._action_back
 
@@ -1099,9 +1202,9 @@ class JaxPlannerStatus(Enum):
     NORMAL = 0
     NO_PROGRESS = 1
     PRECONDITION_POSSIBLY_UNSATISFIED = 2
-    TIME_BUDGET_REACHED = 3
-    ITER_BUDGET_REACHED = 4
-    INVALID_GRADIENT = 5
+    INVALID_GRADIENT = 3
+    TIME_BUDGET_REACHED = 4
+    ITER_BUDGET_REACHED = 5
 
     def is_failure(self) -> bool:
         return self.value >= 3
@@ -1249,26 +1352,27 @@ class JaxBackpropPlanner:
              f'JAX Planner version {__version__}\n'
              f'Python {sys.version}\n'
              f'jax {jax.version.__version__}, jaxlib {jaxlib_version}, '
+              f'optax {optax.__version__}, haiku {hk.__version__}, '
              f'numpy {np.__version__}\n'
              f'devices: {devices_short}\n')
 
     def summarize_hyperparameters(self) -> None:
         print(f'objective hyper-parameters:\n'
-              f' utility_fn ={self.utility.__name__}\n'
-              f' utility args ={self.utility_kwargs}\n'
-              f' use_symlog ={self.use_symlog_reward}\n'
-              f' lookahead ={self.horizon}\n'
-              f' action_bounds ={self._action_bounds}\n'
-              f' fuzzy logic type={type(self.logic).__name__}\n'
-              f' nonfluents exact={self.compile_non_fluent_exact}\n'
-              f' cpfs_no_gradient={self.cpfs_without_grad}\n'
+              f' utility_fn ={self.utility.__name__}\n'
+              f' utility args ={self.utility_kwargs}\n'
+              f' use_symlog ={self.use_symlog_reward}\n'
+              f' lookahead ={self.horizon}\n'
+              f' user_action_bounds={self._action_bounds}\n'
+              f' fuzzy logic type ={type(self.logic).__name__}\n'
+              f' nonfluents exact ={self.compile_non_fluent_exact}\n'
+              f' cpfs_no_gradient ={self.cpfs_without_grad}\n'
               f'optimizer hyper-parameters:\n'
-              f' use_64_bit ={self.use64bit}\n'
-              f' optimizer ={self._optimizer_name.__name__}\n'
-              f' optimizer args ={self._optimizer_kwargs}\n'
-              f' clip_gradient ={self.clip_grad}\n'
-              f' batch_size_train={self.batch_size_train}\n'
-              f' batch_size_test ={self.batch_size_test}')
+              f' use_64_bit ={self.use64bit}\n'
+              f' optimizer ={self._optimizer_name.__name__}\n'
+              f' optimizer args ={self._optimizer_kwargs}\n'
+              f' clip_gradient ={self.clip_grad}\n'
+              f' batch_size_train ={self.batch_size_train}\n'
+              f' batch_size_test ={self.batch_size_test}')
         self.plan.summarize_hyperparameters()
         self.logic.summarize_hyperparameters()
 
@@ -1310,6 +1414,7 @@ class JaxBackpropPlanner:
             policy=self.plan.train_policy,
             n_steps=self.horizon,
             n_batch=self.batch_size_train)
+        self.train_rollouts = train_rollouts
 
         test_rollouts = self.test_compiled.compile_rollouts(
             policy=self.plan.test_policy,
@@ -1417,17 +1522,106 @@ class JaxBackpropPlanner:
 
         return init_train, init_test
 
+    def as_optimization_problem(
+            self, key: Optional[random.PRNGKey]=None,
+            policy_hyperparams: Optional[Pytree]=None,
+            loss_function_updates_key: bool=True,
+            grad_function_updates_key: bool=False) -> Tuple[Callable, Callable, np.ndarray, Callable]:
+        '''Returns a function that computes the loss and a function that
+        computes gradient of the return as a 1D vector given a 1D representation
+        of policy parameters. These functions are designed to be compatible with
+        off-the-shelf optimizers such as scipy.
+
+        Also returns the initial parameter vector to seed an optimizer,
+        as well as a mapping that recovers the parameter pytree from the vector.
+        The PRNG key is updated internally starting from the optional given key.
+
+        Constraints on actions, if they are required, cannot be constructed
+        automatically in the general case. The user should build constraints
+        for each problem in the format required by the downstream optimizer.
+
+        :param key: JAX PRNG key (derived from clock if not provided)
+        :param policy_hyperparameters: hyper-parameters for the policy/plan,
+        such as weights for sigmoid wrapping boolean actions (defaults to 1
+        for all action-fluents if not provided)
+        :param loss_function_updates_key: if True, the loss function
+        updates the PRNG key internally independently of the grad function
+        :param grad_function_updates_key: if True, the gradient function
+        updates the PRNG key internally independently of the loss function.
+        '''
+
+        # if PRNG key is not provided
+        if key is None:
+            key = random.PRNGKey(round(time.time() * 1000))
+
+        # initialize the initial fluents, model parameters, policy hyper-params
+        subs = self.test_compiled.init_values
+        train_subs, _ = self._batched_init_subs(subs)
+        model_params = self.compiled.model_params
+        if policy_hyperparams is None:
+            raise_warning('policy_hyperparams is not set, setting 1.0 for '
+                          'all action-fluents which could be suboptimal.')
+            policy_hyperparams = {action: 1.0
+                                  for action in self.rddl.action_fluents}
+
+        # initialize the policy parameters
+        params_guess, *_ = self.initialize(key, policy_hyperparams, train_subs)
+        guess_1d, unravel_fn = jax.flatten_util.ravel_pytree(params_guess)
+        guess_1d = np.asarray(guess_1d)
+
+        # computes the training loss function and its 1D gradient
+        loss_fn = self._jax_loss(self.train_rollouts)
+
+        @jax.jit
+        def _loss_with_key(key, params_1d):
+            policy_params = unravel_fn(params_1d)
+            loss_val, _ = loss_fn(key, policy_params, policy_hyperparams,
+                                  train_subs, model_params)
+            return loss_val
+
+        @jax.jit
+        def _grad_with_key(key, params_1d):
+            policy_params = unravel_fn(params_1d)
+            grad_fn = jax.grad(loss_fn, argnums=1, has_aux=True)
+            grad_val, _ = grad_fn(key, policy_params, policy_hyperparams,
+                                  train_subs, model_params)
+            grad_1d = jax.flatten_util.ravel_pytree(grad_val)[0]
+            return grad_1d
+
+        def _loss_function(params_1d):
+            nonlocal key
+            if loss_function_updates_key:
+                key, subkey = random.split(key)
+            else:
+                subkey = key
+            loss_val = _loss_with_key(subkey, params_1d)
+            loss_val = float(loss_val)
+            return loss_val
+
+        def _grad_function(params_1d):
+            nonlocal key
+            if grad_function_updates_key:
+                key, subkey = random.split(key)
+            else:
+                subkey = key
+            grad = _grad_with_key(subkey, params_1d)
+            grad = np.asarray(grad)
+            return grad
+
+        return _loss_function, _grad_function, guess_1d, jax.jit(unravel_fn)
+
     # ===========================================================================
     # OPTIMIZE API
     # ===========================================================================
 
     def optimize(self, *args, **kwargs) -> Dict[str, Any]:
-        ''' Compute an optimal policy or plan. Return the callback from training.
+        '''Compute an optimal policy or plan. Return the callback from training.
 
         :param key: JAX PRNG key (derived from clock if not provided)
         :param epochs: the maximum number of steps of gradient descent
         :param train_seconds: total time allocated for gradient descent
         :param plot_step: frequency to plot the plan and save result to disk
+        :param plot_kwargs: additional arguments to pass to the plotter
         :param model_params: optional model-parameters to override default
         :param policy_hyperparams: hyper-parameters for the policy/plan, such as
         weights for sigmoid wrapping boolean actions
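A hedged sketch of driving the new `as_optimization_problem()` API with an external optimizer; `planner` is assumed to be an already-constructed `JaxBackpropPlanner`, and scipy is one possible downstream optimizer as suggested by the docstring:

```python
from scipy.optimize import minimize

loss_fn, grad_fn, guess_1d, unravel_fn = planner.as_optimization_problem()

# minimize the training loss over the flattened (1D) policy parameters
result = minimize(loss_fn, x0=guess_1d, jac=grad_fn, method='L-BFGS-B')

# recover the policy parameter pytree from the optimized 1D vector
policy_params = unravel_fn(result.x)
```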
@@ -1435,7 +1629,9 @@ class JaxBackpropPlanner:
         their values: if None initializes all variables from the RDDL instance
         :param guess: initial policy parameters: if None will use the initializer
         specified in this instance
-        :param verbose: not print (0), print summary (1), print progress (2)
+        :param print_summary: whether to print planner header, parameter
+        summary, and diagnosis
+        :param print_progress: whether to print the progress bar during training
         :param test_rolling_window: the test return is averaged on a rolling
         window of the past test_rolling_window returns when updating the best
         parameters found so far
@@ -1461,11 +1657,13 @@ class JaxBackpropPlanner:
                  epochs: int=999999,
                  train_seconds: float=120.,
                  plot_step: Optional[int]=None,
+                 plot_kwargs: Optional[Dict[str, Any]]=None,
                  model_params: Optional[Dict[str, Any]]=None,
                  policy_hyperparams: Optional[Dict[str, Any]]=None,
                  subs: Optional[Dict[str, Any]]=None,
                  guess: Optional[Pytree]=None,
-                 verbose: int=2,
+                 print_summary: bool=True,
+                 print_progress: bool=True,
                  test_rolling_window: int=10,
                  tqdm_position: Optional[int]=None) -> Generator[Dict[str, Any], None, None]:
         '''Returns a generator for computing an optimal policy or plan.
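A hedged sketch of an `optimize()` call against the 0.3 signature: the single `verbose` level is replaced by the two booleans `print_summary` and `print_progress`, and plotting options are forwarded to `JaxPlannerPlot` through `plot_kwargs` (`planner` is assumed to be a constructed `JaxBackpropPlanner`):

```python
callback = planner.optimize(
    epochs=1000,
    train_seconds=60.,
    plot_step=50,
    plot_kwargs={'show_violin': False},   # forwarded to JaxPlannerPlot
    print_summary=True,                   # header, hyper-parameters, diagnosis
    print_progress=False)                 # suppress the tqdm progress bar
```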
@@ -1476,20 +1674,22 @@ class JaxBackpropPlanner:
         :param epochs: the maximum number of steps of gradient descent
         :param train_seconds: total time allocated for gradient descent
         :param plot_step: frequency to plot the plan and save result to disk
+        :param plot_kwargs: additional arguments to pass to the plotter
         :param model_params: optional model-parameters to override default
         :param policy_hyperparams: hyper-parameters for the policy/plan, such as
         weights for sigmoid wrapping boolean actions
         :param subs: dictionary mapping initial state and non-fluents to
         their values: if None initializes all variables from the RDDL instance
         :param guess: initial policy parameters: if None will use the initializer
-        specified in this instance
-        :param verbose: not print (0), print summary (1), print progress (2)
+            specified in this instance
+        :param print_summary: whether to print planner header, parameter
+        summary, and diagnosis
+        :param print_progress: whether to print the progress bar during training
         :param test_rolling_window: the test return is averaged on a rolling
         window of the past test_rolling_window returns when updating the best
         parameters found so far
         :param tqdm_position: position of tqdm progress bar (for multiprocessing)
         '''
-        verbose = int(verbose)
         start_time = time.time()
         elapsed_outside_loop = 0
 
@@ -1513,7 +1713,7 @@ class JaxBackpropPlanner:
                                 for action in self.rddl.action_fluents}
 
         # print summary of parameters:
-        if verbose >= 1:
+        if print_summary:
             self._summarize_system()
             self.summarize_hyperparameters()
             print(f'optimize() call hyper-parameters:\n'
@@ -1526,8 +1726,10 @@ class JaxBackpropPlanner:
                  f' provide_param_guess={guess is not None}\n'
                  f' test_rolling_window={test_rolling_window}\n'
                  f' plot_frequency ={plot_step}\n'
-                  f' verbose ={verbose}\n')
-        if verbose >= 2 and self.compiled.relaxations:
+                  f' plot_kwargs ={plot_kwargs}\n'
+                  f' print_summary ={print_summary}\n'
+                  f' print_progress ={print_progress}\n')
+        if self.compiled.relaxations:
            print('Some RDDL operations are non-differentiable, '
                  'replacing them with differentiable relaxations:')
            print(self.compiled.summarize_model_relaxations())
@@ -1549,7 +1751,7 @@ class JaxBackpropPlanner:
                           'from the RDDL files.')
         train_subs, test_subs = self._batched_init_subs(subs)
 
-        # initialize, model parameters
+        # initialize model parameters
         if model_params is None:
             model_params = self.compiled.model_params
         model_params_test = self.test_compiled.model_params
@@ -1575,12 +1777,14 @@ class JaxBackpropPlanner:
         if plot_step is None or plot_step <= 0 or plt is None:
             plot = None
         else:
-            plot = JaxPlannerPlot(self.rddl, self.horizon)
+            if plot_kwargs is None:
+                plot_kwargs = {}
+            plot = JaxPlannerPlot(self.rddl, self.horizon, **plot_kwargs)
         xticks, loss_values = [], []
 
         # training loop
         iters = range(epochs)
-        if verbose >= 2:
+        if print_progress:
             iters = tqdm(iters, total=100, position=tqdm_position)
 
         for it in iters:
@@ -1588,9 +1792,18 @@ class JaxBackpropPlanner:
 
             # update the parameters of the plan
             key, subkey = random.split(key)
-            policy_params, converged, opt_state, opt_aux, train_loss, train_log = \
+            policy_params, converged, opt_state, opt_aux, \
+                train_loss, train_log = \
                self.update(subkey, policy_params, policy_hyperparams,
                            train_subs, model_params, opt_state, opt_aux)
+
+            # no progress
+            grad_norm_zero, _ = jax.tree_util.tree_flatten(
+                jax.tree_map(lambda x: np.allclose(x, 0), train_log['grad']))
+            if np.all(grad_norm_zero):
+                status = JaxPlannerStatus.NO_PROGRESS
+
+            # constraint satisfaction problem
             if not np.all(converged):
                 raise_warning(
                     'Projected gradient method for satisfying action concurrency '
@@ -1598,13 +1811,18 @@ class JaxBackpropPlanner:
                     'invalid for the current instance.', 'red')
                 status = JaxPlannerStatus.PRECONDITION_POSSIBLY_UNSATISFIED
 
-            # evaluate losses
+            # numerical error
+            if not np.isfinite(train_loss):
+                raise_warning(
+                    f'Aborting JAX planner due to invalid train loss {train_loss}.',
+                    'red')
+                status = JaxPlannerStatus.INVALID_GRADIENT
+
+            # evaluate test losses and record best plan so far
             test_loss, log = self.test_loss(
                 subkey, policy_params, policy_hyperparams,
                 test_subs, model_params_test)
             test_loss = rolling_test_loss.update(test_loss)
-
-            # record the best plan so far
             if test_loss < best_loss:
                 best_params, best_loss, best_grad = \
                     policy_params, test_loss, train_log['grad']
@@ -1617,11 +1835,12 @@ class JaxBackpropPlanner:
                 action_values = {name: values
                                  for (name, values) in log['fluents'].items()
                                  if name in self.rddl.action_fluents}
-                plot.redraw(xticks, loss_values, action_values)
+                returns = -np.sum(np.asarray(log['reward']), axis=1)
+                plot.redraw(xticks, loss_values, action_values, returns)
 
             # if the progress bar is used
             elapsed = time.time() - start_time - elapsed_outside_loop
-            if verbose >= 2:
+            if print_progress:
                 iters.n = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
                 iters.set_description(
                     f'[{tqdm_position}] {it:6} it / {-train_loss:14.6f} train / '
@@ -1633,19 +1852,6 @@ class JaxBackpropPlanner:
             if it >= epochs - 1:
                 status = JaxPlannerStatus.ITER_BUDGET_REACHED
 
-            # numerical error
-            if not np.isfinite(train_loss):
-                raise_warning(
-                    f'Aborting JAX planner due to invalid train loss {train_loss}.',
-                    'red')
-                status = JaxPlannerStatus.INVALID_GRADIENT
-
-            # no progress
-            grad_norm_zero, _ = jax.tree_util.tree_flatten(
-                jax.tree_map(lambda x: np.allclose(x, 0), train_log['grad']))
-            if np.all(grad_norm_zero):
-                status = JaxPlannerStatus.NO_PROGRESS
-
             # return a callback
             start_time_outside = time.time()
             yield {
@@ -1671,7 +1877,7 @@ class JaxBackpropPlanner:
                 break
 
         # release resources
-        if verbose >= 2:
+        if print_progress:
             iters.close()
         if plot is not None:
             plot.close()
@@ -1688,7 +1894,7 @@ class JaxBackpropPlanner:
                           f'during test evaluation:\n{messages}', 'red')
 
         # summarize and test for convergence
-        if verbose >= 1:
+        if print_summary:
             grad_norm = jax.tree_map(lambda x: np.linalg.norm(x).item(), best_grad)
             diagnosis = self._perform_diagnosis(
                 last_iter_improve, -train_loss, -test_loss, -best_loss, grad_norm)
@@ -1778,17 +1984,19 @@ class JaxBackpropPlanner:
                 raise ValueError(f'State dictionary passed to the JAX policy is '
                                  f'grounded, since it contains the key <{var}>, '
                                  f'but a vectorized environment is required: '
-                                 f'please make sure vectorized=True in the RDDLEnv.')
+                                 f'make sure vectorized = True in the RDDLEnv.')
 
             # must be numeric array
             # exception is for POMDPs at 1st epoch when observ-fluents are None
-            if not jnp.issubdtype(values.dtype, jnp.number) \
-            and not jnp.issubdtype(values.dtype, jnp.bool_):
+            dtype = np.atleast_1d(values).dtype
+            if not jnp.issubdtype(dtype, jnp.number) \
+            and not jnp.issubdtype(dtype, jnp.bool_):
                 if step == 0 and var in self.rddl.observ_fluents:
                     subs[var] = self.test_compiled.init_values[var]
                 else:
-                    raise ValueError(f'Values assigned to pvariable {var} are '
-                                     f'non-numeric of type {values.dtype}: {values}.')
+                    raise ValueError(
+                        f'Values {values} assigned to p-variable <{var}> are '
+                        f'non-numeric of type {dtype}.')
 
         # cast device arrays to numpy
         actions = self.test_policy(key, params, policy_hyperparams, step, subs)
@@ -1801,8 +2009,6 @@ class JaxLineSearchPlanner(JaxBackpropPlanner):
     linear search gradient descent, with the Armijo condition.'''
 
     def __init__(self, *args,
-                 optimizer: Callable[..., optax.GradientTransformation]=optax.sgd,
-                 optimizer_kwargs: Kwargs={'learning_rate': 1.0},
                  decay: float=0.8,
                  c: float=0.1,
                  step_max: float=1.0,
@@ -1825,11 +2031,7 @@ class JaxLineSearchPlanner(JaxBackpropPlanner):
             raise_warning('clip_grad parameter conflicts with '
                           'line search planner and will be ignored.', 'red')
             del kwargs['clip_grad']
-        super(JaxLineSearchPlanner, self).__init__(
-            *args,
-            optimizer=optimizer,
-            optimizer_kwargs=optimizer_kwargs,
-            **kwargs)
+        super(JaxLineSearchPlanner, self).__init__(*args, **kwargs)
 
     def summarize_hyperparameters(self) -> None:
         super(JaxLineSearchPlanner, self).summarize_hyperparameters()
@@ -1878,7 +2080,8 @@ class JaxLineSearchPlanner(JaxBackpropPlanner):
             step = lrmax / decay
             f_step = np.inf
             best_f, best_step, best_params, best_state = np.inf, None, None, None
-            while f_step > f - c * step * gnorm2 and step * decay >= lrmin:
+            while (f_step > f - c * step * gnorm2 and step * decay >= lrmin) \
+                    or not trials:
                 trials += 1
                 step *= decay
                 f_step, new_params, new_state = _jax_wrapped_line_search_trial(
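The backtracking loop above shrinks the step size by `decay` until the Armijo sufficient-decrease test holds (the loop condition is its negation); the added `or not trials` guarantees at least one trial even when the initial step already passes the test. In the usual notation, with step size $\alpha$, decrease constant $c$, and loss $f$, an accepted step satisfies

$$ f\big(\theta - \alpha\,\nabla f(\theta)\big) \;\le\; f(\theta) - c\,\alpha\,\lVert \nabla f(\theta) \rVert^{2}. $$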
@@ -1918,7 +2121,7 @@ def entropic_utility(returns: jnp.ndarray, beta: float) -> float:
 
 @jax.jit
 def mean_variance_utility(returns: jnp.ndarray, beta: float) -> float:
-    return jnp.mean(returns) - (beta / 2.0) * jnp.var(returns)
+    return jnp.mean(returns) - 0.5 * beta * jnp.var(returns)
 
 
 @jax.jit
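The rewritten return statement computes the same mean-variance utility as before, only with the coefficient written as `0.5 * beta`:

$$ U_{\beta}(R) \;=\; \mathbb{E}[R] \;-\; \tfrac{\beta}{2}\,\mathrm{Var}[R]. $$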
@@ -1986,7 +2189,8 @@ class JaxOfflineController(BaseAgent):
     def reset(self) -> None:
         self.step = 0
         if self.train_on_reset and not self.params_given:
-            self.params = self.planner.optimize(key=self.key, **self.train_kwargs)
+            callback = self.planner.optimize(key=self.key, **self.train_kwargs)
+            self.params = callback['best_params']
 
 
 class JaxOnlineController(BaseAgent):
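Since `optimize()` returns the training callback dictionary rather than the parameters themselves, callers that previously used its return value directly should read the `best_params` entry, as the controller fix above does. A minimal sketch (`planner` and `key` are assumed to exist):

```python
callback = planner.optimize(key=key, epochs=500)
params = callback['best_params']   # best policy parameters found during training
```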