pyRDDLGym-jax 0.4__py3-none-any.whl → 1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. pyRDDLGym_jax/__init__.py +1 -1
  2. pyRDDLGym_jax/core/compiler.py +463 -592
  3. pyRDDLGym_jax/core/logic.py +832 -530
  4. pyRDDLGym_jax/core/planner.py +422 -474
  5. pyRDDLGym_jax/core/simulator.py +7 -5
  6. pyRDDLGym_jax/core/tuning.py +390 -584
  7. pyRDDLGym_jax/core/visualization.py +1463 -0
  8. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
  9. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +5 -5
  10. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
  11. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
  12. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
  13. pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
  14. pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
  15. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
  16. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +5 -4
  17. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
  18. pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
  19. pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
  20. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
  21. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +5 -4
  22. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
  23. pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
  24. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
  25. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +7 -6
  26. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
  27. pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
  28. pyRDDLGym_jax/examples/configs/default_replan.cfg +5 -4
  29. pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
  30. pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
  31. pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
  32. pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
  33. pyRDDLGym_jax/examples/run_plan.py +4 -1
  34. pyRDDLGym_jax/examples/run_tune.py +40 -29
  35. {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-1.0.dist-info}/METADATA +164 -105
  36. pyRDDLGym_jax-1.0.dist-info/RECORD +45 -0
  37. {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-1.0.dist-info}/WHEEL +1 -1
  38. pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
  39. pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
  40. pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
  41. pyRDDLGym_jax-0.4.dist-info/RECORD +0 -44
  42. {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-1.0.dist-info}/LICENSE +0 -0
  43. {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-1.0.dist-info}/top_level.txt +0 -0
@@ -31,17 +31,18 @@ from pyRDDLGym.core.policy import BaseAgent
31
31
  from pyRDDLGym_jax import __version__
32
32
  from pyRDDLGym_jax.core import logic
33
33
  from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler
34
- from pyRDDLGym_jax.core.logic import FuzzyLogic
34
+ from pyRDDLGym_jax.core.logic import Logic, FuzzyLogic
35
35
 
36
- # try to import matplotlib, if failed then skip plotting
36
+ # try to load the dash board
37
37
  try:
38
- import matplotlib.pyplot as plt
38
+ from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard
39
39
  except Exception:
40
- raise_warning('failed to import matplotlib: '
41
- 'plotting functionality will be disabled.', 'red')
40
+ raise_warning('Failed to load the dashboard visualization tool: '
41
+ 'please make sure you have installed the required packages.',
42
+ 'red')
42
43
  traceback.print_exc()
43
- plt = None
44
-
44
+ JaxPlannerDashboard = None
45
+
45
46
  Activation = Callable[[jnp.ndarray], jnp.ndarray]
46
47
  Bounds = Dict[str, Tuple[np.ndarray, np.ndarray]]
47
48
  Kwargs = Dict[str, Any]
@@ -57,6 +58,7 @@ Pytree = Any
57
58
  #
58
59
  # ***********************************************************************
59
60
 
61
+
60
62
  def _parse_config_file(path: str):
61
63
  if not os.path.isfile(path):
62
64
  raise FileNotFoundError(f'File {path} does not exist.')
@@ -95,18 +97,25 @@ def _load_config(config, args):
95
97
  # read the model settings
96
98
  logic_name = model_args.get('logic', 'FuzzyLogic')
97
99
  logic_kwargs = model_args.get('logic_kwargs', {})
98
- tnorm_name = model_args.get('tnorm', 'ProductTNorm')
99
- tnorm_kwargs = model_args.get('tnorm_kwargs', {})
100
- comp_name = model_args.get('complement', 'StandardComplement')
101
- comp_kwargs = model_args.get('complement_kwargs', {})
102
- compare_name = model_args.get('comparison', 'SigmoidComparison')
103
- compare_kwargs = model_args.get('comparison_kwargs', {})
104
- sampling_name = model_args.get('sampling', 'GumbelSoftmax')
105
- sampling_kwargs = model_args.get('sampling_kwargs', {})
106
- logic_kwargs['tnorm'] = getattr(logic, tnorm_name)(**tnorm_kwargs)
107
- logic_kwargs['complement'] = getattr(logic, comp_name)(**comp_kwargs)
108
- logic_kwargs['comparison'] = getattr(logic, compare_name)(**compare_kwargs)
109
- logic_kwargs['sampling'] = getattr(logic, sampling_name)(**sampling_kwargs)
100
+ if logic_name == 'FuzzyLogic':
101
+ tnorm_name = model_args.get('tnorm', 'ProductTNorm')
102
+ tnorm_kwargs = model_args.get('tnorm_kwargs', {})
103
+ comp_name = model_args.get('complement', 'StandardComplement')
104
+ comp_kwargs = model_args.get('complement_kwargs', {})
105
+ compare_name = model_args.get('comparison', 'SigmoidComparison')
106
+ compare_kwargs = model_args.get('comparison_kwargs', {})
107
+ sampling_name = model_args.get('sampling', 'GumbelSoftmax')
108
+ sampling_kwargs = model_args.get('sampling_kwargs', {})
109
+ rounding_name = model_args.get('rounding', 'SoftRounding')
110
+ rounding_kwargs = model_args.get('rounding_kwargs', {})
111
+ control_name = model_args.get('control', 'SoftControlFlow')
112
+ control_kwargs = model_args.get('control_kwargs', {})
113
+ logic_kwargs['tnorm'] = getattr(logic, tnorm_name)(**tnorm_kwargs)
114
+ logic_kwargs['complement'] = getattr(logic, comp_name)(**comp_kwargs)
115
+ logic_kwargs['comparison'] = getattr(logic, compare_name)(**compare_kwargs)
116
+ logic_kwargs['sampling'] = getattr(logic, sampling_name)(**sampling_kwargs)
117
+ logic_kwargs['rounding'] = getattr(logic, rounding_name)(**rounding_kwargs)
118
+ logic_kwargs['control'] = getattr(logic, control_name)(**control_kwargs)
110
119
 
111
120
  # read the policy settings
112
121
  plan_method = planner_args.pop('method')
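For illustration, a minimal config sketch exercising the new FuzzyLogic options parsed above, plus the 'dashboard' and 'stopping_rule' keys handled in the next hunk, might look as follows. This is not part of the diff: section names ([Model]/[Optimizer]/[Training]) and key names are assumed from the package's example .cfg files, and all values are placeholders.

# Hypothetical config excerpt, loaded through the string-based API defined
# later in this module; key names follow the parsing code above.
from pyRDDLGym_jax.core.planner import load_config_from_string

CONFIG = """
[Model]
logic='FuzzyLogic'
comparison='SigmoidComparison'
comparison_kwargs={'weight': 20.0}
rounding='SoftRounding'
control='SoftControlFlow'

[Optimizer]
method='JaxStraightLinePlan'
method_kwargs={}
optimizer='rmsprop'
optimizer_kwargs={'learning_rate': 0.1}

[Training]
epochs=1000
train_seconds=60
dashboard=True
stopping_rule='NoImprovementStoppingRule'
stopping_rule_kwargs={'patience': 200}
"""
planner_args, plan_kwargs, train_args = load_config_from_string(CONFIG)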
@@ -157,11 +166,25 @@ def _load_config(config, args):
157
166
  else:
158
167
  planner_args['optimizer'] = optimizer
159
168
 
160
- # read the optimize call settings
169
+ # optimize call RNG key
161
170
  planner_key = train_args.get('key', None)
162
171
  if planner_key is not None:
163
172
  train_args['key'] = random.PRNGKey(planner_key)
164
173
 
174
+ # dashboard
175
+ dashboard_key = train_args.get('dashboard', None)
176
+ if dashboard_key is not None and dashboard_key and JaxPlannerDashboard is not None:
177
+ train_args['dashboard'] = JaxPlannerDashboard()
178
+ elif dashboard_key is not None:
179
+ del train_args['dashboard']
180
+
181
+ # optimize call stopping rule
182
+ stopping_rule = train_args.get('stopping_rule', None)
183
+ if stopping_rule is not None:
184
+ stopping_rule_kwargs = train_args.pop('stopping_rule_kwargs', {})
185
+ train_args['stopping_rule'] = getattr(
186
+ sys.modules[__name__], stopping_rule)(**stopping_rule_kwargs)
187
+
165
188
  return planner_args, plan_kwargs, train_args
166
189
 
167
190
 
@@ -175,7 +198,7 @@ def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
175
198
  '''Loads config file contents specified explicitly as a string value.'''
176
199
  config, args = _parse_config_string(value)
177
200
  return _load_config(config, args)
178
-
201
+
179
202
 
180
203
  # ***********************************************************************
181
204
  # MODEL RELAXATIONS
@@ -193,7 +216,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
193
216
  '''
194
217
 
195
218
  def __init__(self, *args,
196
- logic: FuzzyLogic=FuzzyLogic(),
219
+ logic: Logic=FuzzyLogic(),
197
220
  cpfs_without_grad: Optional[Set[str]]=None,
198
221
  **kwargs) -> None:
199
222
  '''Creates a new RDDL to Jax compiler, where operations that are not
@@ -228,57 +251,55 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
228
251
 
229
252
  # overwrite basic operations with fuzzy ones
230
253
  self.RELATIONAL_OPS = {
231
- '>=': logic.greater_equal(),
232
- '<=': logic.less_equal(),
233
- '<': logic.less(),
234
- '>': logic.greater(),
235
- '==': logic.equal(),
236
- '~=': logic.not_equal()
254
+ '>=': logic.greater_equal,
255
+ '<=': logic.less_equal,
256
+ '<': logic.less,
257
+ '>': logic.greater,
258
+ '==': logic.equal,
259
+ '~=': logic.not_equal
237
260
  }
238
- self.LOGICAL_NOT = logic.logical_not()
261
+ self.LOGICAL_NOT = logic.logical_not
239
262
  self.LOGICAL_OPS = {
240
- '^': logic.logical_and(),
241
- '&': logic.logical_and(),
242
- '|': logic.logical_or(),
243
- '~': logic.xor(),
244
- '=>': logic.implies(),
245
- '<=>': logic.equiv()
263
+ '^': logic.logical_and,
264
+ '&': logic.logical_and,
265
+ '|': logic.logical_or,
266
+ '~': logic.xor,
267
+ '=>': logic.implies,
268
+ '<=>': logic.equiv
246
269
  }
247
- self.AGGREGATION_OPS['forall'] = logic.forall()
248
- self.AGGREGATION_OPS['exists'] = logic.exists()
249
- self.AGGREGATION_OPS['argmin'] = logic.argmin()
250
- self.AGGREGATION_OPS['argmax'] = logic.argmax()
251
- self.KNOWN_UNARY['sgn'] = logic.sgn()
252
- self.KNOWN_UNARY['floor'] = logic.floor()
253
- self.KNOWN_UNARY['ceil'] = logic.ceil()
254
- self.KNOWN_UNARY['round'] = logic.round()
255
- self.KNOWN_UNARY['sqrt'] = logic.sqrt()
256
- self.KNOWN_BINARY['div'] = logic.div()
257
- self.KNOWN_BINARY['mod'] = logic.mod()
258
- self.KNOWN_BINARY['fmod'] = logic.mod()
259
- self.IF_HELPER = logic.control_if()
260
- self.SWITCH_HELPER = logic.control_switch()
261
- self.BERNOULLI_HELPER = logic.bernoulli()
262
- self.DISCRETE_HELPER = logic.discrete()
263
- self.POISSON_HELPER = logic.poisson()
264
- self.GEOMETRIC_HELPER = logic.geometric()
265
-
266
- def _jax_stop_grad(self, jax_expr):
267
-
270
+ self.AGGREGATION_OPS['forall'] = logic.forall
271
+ self.AGGREGATION_OPS['exists'] = logic.exists
272
+ self.AGGREGATION_OPS['argmin'] = logic.argmin
273
+ self.AGGREGATION_OPS['argmax'] = logic.argmax
274
+ self.KNOWN_UNARY['sgn'] = logic.sgn
275
+ self.KNOWN_UNARY['floor'] = logic.floor
276
+ self.KNOWN_UNARY['ceil'] = logic.ceil
277
+ self.KNOWN_UNARY['round'] = logic.round
278
+ self.KNOWN_UNARY['sqrt'] = logic.sqrt
279
+ self.KNOWN_BINARY['div'] = logic.div
280
+ self.KNOWN_BINARY['mod'] = logic.mod
281
+ self.KNOWN_BINARY['fmod'] = logic.mod
282
+ self.IF_HELPER = logic.control_if
283
+ self.SWITCH_HELPER = logic.control_switch
284
+ self.BERNOULLI_HELPER = logic.bernoulli
285
+ self.DISCRETE_HELPER = logic.discrete
286
+ self.POISSON_HELPER = logic.poisson
287
+ self.GEOMETRIC_HELPER = logic.geometric
288
+
289
+ def _jax_stop_grad(self, jax_expr):
268
290
  def _jax_wrapped_stop_grad(x, params, key):
269
- sample, key, error = jax_expr(x, params, key)
291
+ sample, key, error, params = jax_expr(x, params, key)
270
292
  sample = jax.lax.stop_gradient(sample)
271
- return sample, key, error
272
-
293
+ return sample, key, error, params
273
294
  return _jax_wrapped_stop_grad
274
295
 
275
- def _compile_cpfs(self, info):
296
+ def _compile_cpfs(self, init_params):
276
297
  cpfs_cast = set()
277
298
  jax_cpfs = {}
278
299
  for (_, cpfs) in self.levels.items():
279
300
  for cpf in cpfs:
280
301
  _, expr = self.rddl.cpfs[cpf]
281
- jax_cpfs[cpf] = self._jax(expr, info, dtype=self.REAL)
302
+ jax_cpfs[cpf] = self._jax(expr, init_params, dtype=self.REAL)
282
303
  if self.rddl.variable_ranges[cpf] != 'real':
283
304
  cpfs_cast.add(cpf)
284
305
  if cpf in self.cpfs_without_grad:
@@ -289,17 +310,15 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
289
310
  f'{cpfs_cast} be cast to float.')
290
311
  if self.cpfs_without_grad:
291
312
  raise_warning(f'User requested that gradients not flow '
292
- f'through CPFs {self.cpfs_without_grad}.')
313
+ f'through CPFs {self.cpfs_without_grad}.')
314
+
293
315
  return jax_cpfs
294
316
 
295
- def _jax_kron(self, expr, info):
296
- if self.logic.verbose:
297
- raise_warning('JAX gradient compiler ignores KronDelta '
298
- 'during compilation.')
317
+ def _jax_kron(self, expr, init_params):
299
318
  arg, = expr.args
300
- arg = self._jax(arg, info)
319
+ arg = self._jax(arg, init_params)
301
320
  return arg
302
-
321
+
303
322
 
304
323
  # ***********************************************************************
305
324
  # ALL VERSIONS OF JAX PLANS
@@ -309,6 +328,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
309
328
  #
310
329
  # ***********************************************************************
311
330
 
331
+
312
332
  class JaxPlan:
313
333
  '''Base class for all JAX policy representations.'''
314
334
 
@@ -320,7 +340,7 @@ class JaxPlan:
320
340
  self.bounds = None
321
341
 
322
342
  def summarize_hyperparameters(self) -> None:
323
- pass
343
+ print(self.__str__())
324
344
 
325
345
  def compile(self, compiled: JaxRDDLCompilerWithGrad,
326
346
  _bounds: Bounds,
@@ -363,7 +383,7 @@ class JaxPlan:
363
383
  self._projection = value
364
384
 
365
385
  def _calculate_action_info(self, compiled: JaxRDDLCompilerWithGrad,
366
- user_bounds: Bounds,
386
+ user_bounds: Bounds,
367
387
  horizon: int):
368
388
  shapes, bounds, bounds_safe, cond_lists = {}, {}, {}, {}
369
389
  for (name, prange) in compiled.rddl.variable_ranges.items():
@@ -446,24 +466,24 @@ class JaxStraightLinePlan(JaxPlan):
446
466
  self._wrap_softmax = wrap_softmax
447
467
  self._use_new_projection = use_new_projection
448
468
  self._max_constraint_iter = max_constraint_iter
449
-
450
- def summarize_hyperparameters(self) -> None:
469
+
470
+ def __str__(self) -> str:
451
471
  bounds = '\n '.join(
452
472
  map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
453
- print(f'policy hyper-parameters:\n'
454
- f' initializer ={self._initializer_base}\n'
455
- f'constraint-sat strategy (simple):\n'
456
- f' parsed_action_bounds =\n {bounds}\n'
457
- f' wrap_sigmoid ={self._wrap_sigmoid}\n'
458
- f' wrap_sigmoid_min_prob={self._min_action_prob}\n'
459
- f' wrap_non_bool ={self._wrap_non_bool}\n'
460
- f'constraint-sat strategy (complex):\n'
461
- f' wrap_softmax ={self._wrap_softmax}\n'
462
- f' use_new_projection ={self._use_new_projection}\n'
463
- f' max_projection_iters ={self._max_constraint_iter}')
473
+ return (f'policy hyper-parameters:\n'
474
+ f' initializer ={self._initializer_base}\n'
475
+ f'constraint-sat strategy (simple):\n'
476
+ f' parsed_action_bounds =\n {bounds}\n'
477
+ f' wrap_sigmoid ={self._wrap_sigmoid}\n'
478
+ f' wrap_sigmoid_min_prob={self._min_action_prob}\n'
479
+ f' wrap_non_bool ={self._wrap_non_bool}\n'
480
+ f'constraint-sat strategy (complex):\n'
481
+ f' wrap_softmax ={self._wrap_softmax}\n'
482
+ f' use_new_projection ={self._use_new_projection}\n'
483
+ f' max_projection_iters ={self._max_constraint_iter}')
464
484
 
465
485
  def compile(self, compiled: JaxRDDLCompilerWithGrad,
466
- _bounds: Bounds,
486
+ _bounds: Bounds,
467
487
  horizon: int) -> None:
468
488
  rddl = compiled.rddl
469
489
 
@@ -504,7 +524,7 @@ class JaxStraightLinePlan(JaxPlan):
504
524
  def _jax_bool_action_to_param(var, action, hyperparams):
505
525
  if wrap_sigmoid:
506
526
  weight = hyperparams[var]
507
- return (-1.0 / weight) * jnp.log(1.0 / action - 1.0)
527
+ return jax.scipy.special.logit(action) / weight
508
528
  else:
509
529
  return action
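Note on the rewrite above: jax.scipy.special.logit(a) = log(a / (1 - a)) = -log(1/a - 1), so logit(action) / weight is algebraically identical to the previous expression (-1.0 / weight) * jnp.log(1.0 / action - 1.0); only the formulation changes.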
510
530
 
@@ -513,14 +533,13 @@ class JaxStraightLinePlan(JaxPlan):
513
533
  def _jax_non_bool_param_to_action(var, param, hyperparams):
514
534
  if wrap_non_bool:
515
535
  lower, upper = bounds_safe[var]
516
- action = jnp.select(
517
- condlist=cond_lists[var],
518
- choicelist=[
519
- lower + (upper - lower) * jax.nn.sigmoid(param),
520
- lower + (jax.nn.elu(param) + 1.0),
521
- upper - (jax.nn.elu(-param) + 1.0),
522
- param
523
- ]
536
+ mb, ml, mu, mn = [mask.astype(compiled.REAL)
537
+ for mask in cond_lists[var]]
538
+ action = (
539
+ mb * (lower + (upper - lower) * jax.nn.sigmoid(param)) +
540
+ ml * (lower + (jax.nn.elu(param) + 1.0)) +
541
+ mu * (upper - (jax.nn.elu(-param) + 1.0)) +
542
+ mn * param
524
543
  )
525
544
  else:
526
545
  action = param
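The jnp.select over the bound cases is replaced here by a sum of mask-weighted branches; the same replacement appears again in the deep reactive policy below. A small standalone sketch of the mapping (names and shapes illustrative, not the package's API):

# Sketch: map an unconstrained parameter into its action range using the four
# cases distinguished by cond_lists: both bounds finite, only lower finite,
# only upper finite, neither finite. Masks are 0/1 floats selecting exactly
# one branch per entry; lower/upper are assumed to be finite "safe" stand-ins
# (cf. bounds_safe) so that masked-out branches stay finite.
import jax
import jax.numpy as jnp

def param_to_action(param, lower, upper, masks):
    mb, ml, mu, mn = masks  # bounded, lower-only, upper-only, unbounded
    return (mb * (lower + (upper - lower) * jax.nn.sigmoid(param)) +
            ml * (lower + (jax.nn.elu(param) + 1.0)) +
            mu * (upper - (jax.nn.elu(-param) + 1.0)) +
            mn * param)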
@@ -780,7 +799,7 @@ class JaxDeepReactivePolicy(JaxPlan):
780
799
  def __init__(self, topology: Optional[Sequence[int]]=None,
781
800
  activation: Activation=jnp.tanh,
782
801
  initializer: hk.initializers.Initializer=hk.initializers.VarianceScaling(scale=2.0),
783
- normalize: bool=False,
802
+ normalize: bool=False,
784
803
  normalize_per_layer: bool=False,
785
804
  normalizer_kwargs: Optional[Kwargs]=None,
786
805
  wrap_non_bool: bool=False) -> None:
@@ -812,23 +831,23 @@ class JaxDeepReactivePolicy(JaxPlan):
812
831
  normalizer_kwargs = {'create_offset': True, 'create_scale': True}
813
832
  self._normalizer_kwargs = normalizer_kwargs
814
833
  self._wrap_non_bool = wrap_non_bool
815
-
816
- def summarize_hyperparameters(self) -> None:
834
+
835
+ def __str__(self) -> str:
817
836
  bounds = '\n '.join(
818
837
  map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
819
- print(f'policy hyper-parameters:\n'
820
- f' topology ={self._topology}\n'
821
- f' activation_fn ={self._activations[0].__name__}\n'
822
- f' initializer ={type(self._initializer_base).__name__}\n'
823
- f' apply_input_norm ={self._normalize}\n'
824
- f' input_norm_layerwise={self._normalize_per_layer}\n'
825
- f' input_norm_args ={self._normalizer_kwargs}\n'
826
- f'constraint-sat strategy:\n'
827
- f' parsed_action_bounds=\n {bounds}\n'
828
- f' wrap_non_bool ={self._wrap_non_bool}')
829
-
838
+ return (f'policy hyper-parameters:\n'
839
+ f' topology ={self._topology}\n'
840
+ f' activation_fn ={self._activations[0].__name__}\n'
841
+ f' initializer ={type(self._initializer_base).__name__}\n'
842
+ f' apply_input_norm ={self._normalize}\n'
843
+ f' input_norm_layerwise={self._normalize_per_layer}\n'
844
+ f' input_norm_args ={self._normalizer_kwargs}\n'
845
+ f'constraint-sat strategy:\n'
846
+ f' parsed_action_bounds=\n {bounds}\n'
847
+ f' wrap_non_bool ={self._wrap_non_bool}')
848
+
830
849
  def compile(self, compiled: JaxRDDLCompilerWithGrad,
831
- _bounds: Bounds,
850
+ _bounds: Bounds,
832
851
  horizon: int) -> None:
833
852
  rddl = compiled.rddl
834
853
 
@@ -881,7 +900,7 @@ class JaxDeepReactivePolicy(JaxPlan):
881
900
  if normalize_per_layer and value_size == 1:
882
901
  raise_warning(
883
902
  f'Cannot apply layer norm to state-fluent <{var}> '
884
- f'of size 1: setting normalize_per_layer = False.',
903
+ f'of size 1: setting normalize_per_layer = False.',
885
904
  'red')
886
905
  normalize_per_layer = False
887
906
  non_bool_dims += value_size
@@ -906,8 +925,8 @@ class JaxDeepReactivePolicy(JaxPlan):
906
925
  else:
907
926
  if normalize and normalize_per_layer:
908
927
  normalizer = hk.LayerNorm(
909
- axis=-1, param_axis=-1,
910
- name=f'input_norm_{input_names[var]}',
928
+ axis=-1, param_axis=-1,
929
+ name=f'input_norm_{input_names[var]}',
911
930
  **self._normalizer_kwargs)
912
931
  state = normalizer(state)
913
932
  states_non_bool.append(state)
@@ -917,7 +936,7 @@ class JaxDeepReactivePolicy(JaxPlan):
917
936
  # optionally perform layer normalization on the non-bool inputs
918
937
  if normalize and not normalize_per_layer and non_bool_dims:
919
938
  normalizer = hk.LayerNorm(
920
- axis=-1, param_axis=-1, name='input_norm',
939
+ axis=-1, param_axis=-1, name='input_norm',
921
940
  **self._normalizer_kwargs)
922
941
  normalized = normalizer(state[:non_bool_dims])
923
942
  state = state.at[:non_bool_dims].set(normalized)
@@ -950,14 +969,13 @@ class JaxDeepReactivePolicy(JaxPlan):
950
969
  else:
951
970
  if wrap_non_bool:
952
971
  lower, upper = bounds_safe[var]
953
- action = jnp.select(
954
- condlist=cond_lists[var],
955
- choicelist=[
956
- lower + (upper - lower) * jax.nn.sigmoid(output),
957
- lower + (jax.nn.elu(output) + 1.0),
958
- upper - (jax.nn.elu(-output) + 1.0),
959
- output
960
- ]
972
+ mb, ml, mu, mn = [mask.astype(compiled.REAL)
973
+ for mask in cond_lists[var]]
974
+ action = (
975
+ mb * (lower + (upper - lower) * jax.nn.sigmoid(output)) +
976
+ ml * (lower + (jax.nn.elu(output) + 1.0)) +
977
+ mu * (upper - (jax.nn.elu(-output) + 1.0)) +
978
+ mn * output
961
979
  )
962
980
  else:
963
981
  action = output
@@ -1049,7 +1067,7 @@ class JaxDeepReactivePolicy(JaxPlan):
1049
1067
 
1050
1068
  def guess_next_epoch(self, params: Pytree) -> Pytree:
1051
1069
  return params
1052
-
1070
+
1053
1071
 
1054
1072
  # ***********************************************************************
1055
1073
  # ALL VERSIONS OF JAX PLANNER
@@ -1059,6 +1077,7 @@ class JaxDeepReactivePolicy(JaxPlan):
1059
1077
  #
1060
1078
  # ***********************************************************************
1061
1079
 
1080
+
1062
1081
  class RollingMean:
1063
1082
  '''Maintains an estimate of the rolling mean of a stream of real-valued
1064
1083
  observations.'''
@@ -1077,113 +1096,6 @@ class RollingMean:
1077
1096
  return self._total / len(memory)
1078
1097
 
1079
1098
 
1080
- class JaxPlannerPlot:
1081
- '''Supports plotting and visualization of a JAX policy in real time.'''
1082
-
1083
- def __init__(self, rddl: RDDLPlanningModel, horizon: int,
1084
- show_violin: bool=True, show_action: bool=True) -> None:
1085
- '''Creates a new planner visualizer.
1086
-
1087
- :param rddl: the planning model to optimize
1088
- :param horizon: the lookahead or planning horizon
1089
- :param show_violin: whether to show the distribution of batch losses
1090
- :param show_action: whether to show heatmaps of the action fluents
1091
- '''
1092
- num_plots = 1
1093
- if show_violin:
1094
- num_plots += 1
1095
- if show_action:
1096
- num_plots += len(rddl.action_fluents)
1097
- self._fig, axes = plt.subplots(num_plots)
1098
- if num_plots == 1:
1099
- axes = [axes]
1100
-
1101
- # prepare the loss plot
1102
- self._loss_ax = axes[0]
1103
- self._loss_ax.autoscale(enable=True)
1104
- self._loss_ax.set_xlabel('training time')
1105
- self._loss_ax.set_ylabel('loss value')
1106
- self._loss_plot = self._loss_ax.plot(
1107
- [], [], linestyle=':', marker='o', markersize=2)[0]
1108
- self._loss_back = self._fig.canvas.copy_from_bbox(self._loss_ax.bbox)
1109
-
1110
- # prepare the violin plot
1111
- if show_violin:
1112
- self._hist_ax = axes[1]
1113
- else:
1114
- self._hist_ax = None
1115
-
1116
- # prepare the action plots
1117
- if show_action:
1118
- self._action_ax = {name: axes[idx + (2 if show_violin else 1)]
1119
- for (idx, name) in enumerate(rddl.action_fluents)}
1120
- self._action_plots = {}
1121
- for name in rddl.action_fluents:
1122
- ax = self._action_ax[name]
1123
- if rddl.variable_ranges[name] == 'bool':
1124
- vmin, vmax = 0.0, 1.0
1125
- else:
1126
- vmin, vmax = None, None
1127
- action_dim = 1
1128
- for dim in rddl.object_counts(rddl.variable_params[name]):
1129
- action_dim *= dim
1130
- action_plot = ax.pcolormesh(
1131
- np.zeros((action_dim, horizon)),
1132
- cmap='seismic', vmin=vmin, vmax=vmax)
1133
- ax.set_aspect('auto')
1134
- ax.set_xlabel('decision epoch')
1135
- ax.set_ylabel(name)
1136
- plt.colorbar(action_plot, ax=ax)
1137
- self._action_plots[name] = action_plot
1138
- self._action_back = {name: self._fig.canvas.copy_from_bbox(ax.bbox)
1139
- for (name, ax) in self._action_ax.items()}
1140
- else:
1141
- self._action_ax = None
1142
- self._action_plots = None
1143
- self._action_back = None
1144
-
1145
- plt.tight_layout()
1146
- plt.show(block=False)
1147
-
1148
- def redraw(self, xticks, losses, actions, returns) -> None:
1149
-
1150
- # draw the loss curve
1151
- self._fig.canvas.restore_region(self._loss_back)
1152
- self._loss_plot.set_xdata(xticks)
1153
- self._loss_plot.set_ydata(losses)
1154
- self._loss_ax.set_xlim([0, len(xticks)])
1155
- self._loss_ax.set_ylim([np.min(losses), np.max(losses)])
1156
- self._loss_ax.draw_artist(self._loss_plot)
1157
- self._fig.canvas.blit(self._loss_ax.bbox)
1158
-
1159
- # draw the violin plot
1160
- if self._hist_ax is not None:
1161
- self._hist_ax.clear()
1162
- self._hist_ax.set_xlabel('loss value')
1163
- self._hist_ax.set_ylabel('density')
1164
- self._hist_ax.violinplot(returns, vert=False, showmeans=True)
1165
-
1166
- # draw the actions
1167
- if self._action_ax is not None:
1168
- for (name, values) in actions.items():
1169
- values = np.mean(values, axis=0, dtype=float)
1170
- values = np.reshape(values, newshape=(values.shape[0], -1)).T
1171
- self._fig.canvas.restore_region(self._action_back[name])
1172
- self._action_plots[name].set_array(values)
1173
- self._action_ax[name].draw_artist(self._action_plots[name])
1174
- self._fig.canvas.blit(self._action_ax[name].bbox)
1175
- self._action_plots[name].set_clim([np.min(values), np.max(values)])
1176
-
1177
- self._fig.canvas.draw()
1178
- self._fig.canvas.flush_events()
1179
-
1180
- def close(self) -> None:
1181
- plt.close(self._fig)
1182
- del self._loss_ax, self._hist_ax, self._action_ax, \
1183
- self._loss_plot, self._action_plots, self._fig, \
1184
- self._loss_back, self._action_back
1185
-
1186
-
1187
1099
  class JaxPlannerStatus(Enum):
1188
1100
  '''Represents the status of a policy update from the JAX planner,
1189
1101
  including whether the update resulted in nan gradient,
@@ -1191,16 +1103,50 @@ class JaxPlannerStatus(Enum):
1191
1103
  can be used to monitor and act based on the planner's progress.'''
1192
1104
 
1193
1105
  NORMAL = 0
1194
- NO_PROGRESS = 1
1195
- PRECONDITION_POSSIBLY_UNSATISFIED = 2
1196
- INVALID_GRADIENT = 3
1197
- TIME_BUDGET_REACHED = 4
1198
- ITER_BUDGET_REACHED = 5
1106
+ STOPPING_RULE_REACHED = 1
1107
+ NO_PROGRESS = 2
1108
+ PRECONDITION_POSSIBLY_UNSATISFIED = 3
1109
+ INVALID_GRADIENT = 4
1110
+ TIME_BUDGET_REACHED = 5
1111
+ ITER_BUDGET_REACHED = 6
1199
1112
 
1200
- def is_failure(self) -> bool:
1201
- return self.value >= 3
1113
+ def is_terminal(self) -> bool:
1114
+ return self.value == 1 or self.value >= 4
1202
1115
 
1203
1116
 
1117
+ class JaxPlannerStoppingRule:
1118
+ '''The base class of all planner stopping rules.'''
1119
+
1120
+ def reset(self) -> None:
1121
+ raise NotImplementedError
1122
+
1123
+ def monitor(self, callback: Dict[str, Any]) -> bool:
1124
+ raise NotImplementedError
1125
+
1126
+
1127
+ class NoImprovementStoppingRule(JaxPlannerStoppingRule):
1128
+ '''Stopping rule based on no improvement for a fixed number of iterations.'''
1129
+
1130
+ def __init__(self, patience: int) -> None:
1131
+ self.patience = patience
1132
+
1133
+ def reset(self) -> None:
1134
+ self.callback = None
1135
+ self.iters_since_last_update = 0
1136
+
1137
+ def monitor(self, callback: Dict[str, Any]) -> bool:
1138
+ if self.callback is None \
1139
+ or callback['best_return'] > self.callback['best_return']:
1140
+ self.callback = callback
1141
+ self.iters_since_last_update = 0
1142
+ else:
1143
+ self.iters_since_last_update += 1
1144
+ return self.iters_since_last_update >= self.patience
1145
+
1146
+ def __str__(self) -> str:
1147
+ return f'No improvement for {self.patience} iterations'
1148
+
1149
+
1204
1150
  class JaxBackpropPlanner:
1205
1151
  '''A class for optimizing an action sequence in the given RDDL MDP using
1206
1152
  gradient descent.'''
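A hedged usage sketch for the new stopping rule and dashboard hooks (argument names are taken from this diff; the surrounding setup, such as the planning model and plan construction, is assumed):

# Stop early once the best return has not improved for `patience` iterations,
# and stream progress to the dashboard. `env.model` is a placeholder for a
# pyRDDLGym planning model; serving the dashboard UI is not shown.
planner = JaxBackpropPlanner(rddl=env.model, plan=JaxStraightLinePlan())
dashboard = JaxPlannerDashboard() if JaxPlannerDashboard is not None else None
for callback in planner.optimize_generator(
        epochs=10000,
        train_seconds=120.0,
        dashboard=dashboard,
        stopping_rule=NoImprovementStoppingRule(patience=200)):
    if callback['status'].is_terminal():
        break
print(callback['best_return'])  # the quantity monitored by the stopping rule

The plain optimize() call documented further below accepts the same dashboard and stopping_rule arguments.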
@@ -1215,13 +1161,16 @@ class JaxBackpropPlanner:
1215
1161
  optimizer: Callable[..., optax.GradientTransformation]=optax.rmsprop,
1216
1162
  optimizer_kwargs: Optional[Kwargs]=None,
1217
1163
  clip_grad: Optional[float]=None,
1218
- logic: FuzzyLogic=FuzzyLogic(),
1164
+ line_search_kwargs: Optional[Kwargs]=None,
1165
+ noise_kwargs: Optional[Kwargs]=None,
1166
+ logic: Logic=FuzzyLogic(),
1219
1167
  use_symlog_reward: bool=False,
1220
1168
  utility: Union[Callable[[jnp.ndarray], float], str]='mean',
1221
1169
  utility_kwargs: Optional[Kwargs]=None,
1222
1170
  cpfs_without_grad: Optional[Set[str]]=None,
1223
1171
  compile_non_fluent_exact: bool=True,
1224
- logger: Optional[Logger]=None) -> None:
1172
+ logger: Optional[Logger]=None,
1173
+ dashboard_viz: Optional[Any]=None) -> None:
1225
1174
  '''Creates a new gradient-based algorithm for optimizing action sequences
1226
1175
  (plan) in the given RDDL. Some operations will be converted to their
1227
1176
  differentiable counterparts; the specific operations can be customized
@@ -1241,7 +1190,10 @@ class JaxBackpropPlanner:
1241
1190
  :param optimizer_kwargs: a dictionary of parameters to pass to the SGD
1242
1191
  factory (e.g. which parameters are controllable externally)
1243
1192
  :param clip_grad: maximum magnitude of gradient updates
1244
- :param logic: a subclass of FuzzyLogic for mapping exact mathematical
1193
+ :param line_search_kwargs: parameters to pass to optional line search
1194
+ method to scale learning rate
1195
+ :param noise_kwargs: parameters of optional gradient noise
1196
+ :param logic: a subclass of Logic for mapping exact mathematical
1245
1197
  operations to their differentiable counterparts
1246
1198
  :param use_symlog_reward: whether to use the symlog transform on the
1247
1199
  reward as a form of normalization
@@ -1256,6 +1208,8 @@ class JaxBackpropPlanner:
1256
1208
  :param compile_non_fluent_exact: whether non-fluent expressions
1257
1209
  are always compiled using exact JAX expressions
1258
1210
  :param logger: to log information about compilation to file
1211
+ :param dashboard_viz: optional visualizer object from the environment
1212
+ to pass to the dashboard to visualize the policy
1259
1213
  '''
1260
1214
  self.rddl = rddl
1261
1215
  self.plan = plan
@@ -1270,29 +1224,34 @@ class JaxBackpropPlanner:
1270
1224
  action_bounds = {}
1271
1225
  self._action_bounds = action_bounds
1272
1226
  self.use64bit = use64bit
1273
- self._optimizer_name = optimizer
1227
+ self.optimizer_name = optimizer
1274
1228
  if optimizer_kwargs is None:
1275
1229
  optimizer_kwargs = {'learning_rate': 0.1}
1276
- self._optimizer_kwargs = optimizer_kwargs
1230
+ self.optimizer_kwargs = optimizer_kwargs
1277
1231
  self.clip_grad = clip_grad
1232
+ self.line_search_kwargs = line_search_kwargs
1233
+ self.noise_kwargs = noise_kwargs
1278
1234
 
1279
1235
  # set optimizer
1280
1236
  try:
1281
1237
  optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs)
1282
1238
  except Exception as _:
1283
1239
  raise_warning(
1284
- 'Failed to inject hyperparameters into optax optimizer, '
1240
+ f'Failed to inject hyperparameters into optax optimizer {optimizer}, '
1285
1241
  'rolling back to safer method: please note that modification of '
1286
- 'optimizer hyperparameters will not work, and it is '
1287
- 'recommended to update optax and related packages.', 'red')
1288
- optimizer = optimizer(**optimizer_kwargs)
1289
- if clip_grad is None:
1290
- self.optimizer = optimizer
1291
- else:
1292
- self.optimizer = optax.chain(
1293
- optax.clip(clip_grad),
1294
- optimizer
1295
- )
1242
+ 'optimizer hyperparameters will not work.', 'red')
1243
+ optimizer = optimizer(**optimizer_kwargs)
1244
+
1245
+ # apply optimizer chain of transformations
1246
+ pipeline = []
1247
+ if clip_grad is not None:
1248
+ pipeline.append(optax.clip(clip_grad))
1249
+ if noise_kwargs is not None:
1250
+ pipeline.append(optax.add_noise(**noise_kwargs))
1251
+ pipeline.append(optimizer)
1252
+ if line_search_kwargs is not None:
1253
+ pipeline.append(optax.scale_by_zoom_linesearch(**line_search_kwargs))
1254
+ self.optimizer = optax.chain(*pipeline)
1296
1255
 
1297
1256
  # set utility
1298
1257
  if isinstance(utility, str):
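The constructor now assembles the optimizer as an optax chain. A standalone sketch of the resulting pipeline for a hypothetical configuration with clipping, gradient noise, and a zoom line search (keyword values are illustrative, not package defaults):

# Assembling the same chain of transformations directly with optax.
import optax

pipeline = [
    optax.clip(1.0),                                          # clip_grad
    optax.add_noise(eta=0.01, gamma=0.55, seed=42),           # noise_kwargs
    optax.inject_hyperparams(optax.rmsprop)(learning_rate=0.1),
    optax.scale_by_zoom_linesearch(max_linesearch_steps=15),  # line_search_kwargs
]
optimizer = optax.chain(*pipeline)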
@@ -1325,11 +1284,12 @@ class JaxBackpropPlanner:
1325
1284
  self.cpfs_without_grad = cpfs_without_grad
1326
1285
  self.compile_non_fluent_exact = compile_non_fluent_exact
1327
1286
  self.logger = logger
1287
+ self.dashboard_viz = dashboard_viz
1328
1288
 
1329
1289
  self._jax_compile_rddl()
1330
1290
  self._jax_compile_optimizer()
1331
1291
 
1332
- def _summarize_system(self) -> None:
1292
+ def summarize_system(self) -> str:
1333
1293
  try:
1334
1294
  jaxlib_version = jax._src.lib.version_str
1335
1295
  except Exception as _:
@@ -1340,42 +1300,63 @@ class JaxBackpropPlanner:
1340
1300
  except Exception as _:
1341
1301
  devices_short = 'N/A'
1342
1302
  LOGO = \
1303
+ r"""
1304
+ __ ______ __ __ ______ __ ______ __ __
1305
+ /\ \ /\ __ \ /\_\_\_\ /\ == \/\ \ /\ __ \ /\ "-.\ \
1306
+ _\_\ \\ \ __ \\/_/\_\/_\ \ _-/\ \ \____\ \ __ \\ \ \-. \
1307
+ /\_____\\ \_\ \_\ /\_\/\_\\ \_\ \ \_____\\ \_\ \_\\ \_\\"\_\
1308
+ \/_____/ \/_/\/_/ \/_/\/_/ \/_/ \/_____/ \/_/\/_/ \/_/ \/_/
1343
1309
  """
1344
- __ ______ __ __ ______ __ ______ __ __
1345
- /\ \ /\ __ \ /\_\_\_\ /\ == \/\ \ /\ __ \ /\ "-.\ \
1346
- _\_\ \ \ \ __ \ \/_/\_\/_ \ \ _-/\ \ \____ \ \ __ \ \ \ \-. \
1347
- /\_____\ \ \_\ \_\ /\_\/\_\ \ \_\ \ \_____\ \ \_\ \_\ \ \_\\"\_\
1348
- \/_____/ \/_/\/_/ \/_/\/_/ \/_/ \/_____/ \/_/\/_/ \/_/ \/_/
1349
- """
1350
-
1351
- print('\n'
1352
- f'{LOGO}\n'
1353
- f'Version {__version__}\n'
1354
- f'Python {sys.version}\n'
1355
- f'jax {jax.version.__version__}, jaxlib {jaxlib_version}, '
1356
- f'optax {optax.__version__}, haiku {hk.__version__}, '
1357
- f'numpy {np.__version__}\n'
1358
- f'devices: {devices_short}\n')
1310
+
1311
+ return ('\n'
1312
+ f'{LOGO}\n'
1313
+ f'Version {__version__}\n'
1314
+ f'Python {sys.version}\n'
1315
+ f'jax {jax.version.__version__}, jaxlib {jaxlib_version}, '
1316
+ f'optax {optax.__version__}, haiku {hk.__version__}, '
1317
+ f'numpy {np.__version__}\n'
1318
+ f'devices: {devices_short}\n')
1319
+
1320
+ def __str__(self) -> str:
1321
+ result = (f'objective hyper-parameters:\n'
1322
+ f' utility_fn ={self.utility.__name__}\n'
1323
+ f' utility args ={self.utility_kwargs}\n'
1324
+ f' use_symlog ={self.use_symlog_reward}\n'
1325
+ f' lookahead ={self.horizon}\n'
1326
+ f' user_action_bounds={self._action_bounds}\n'
1327
+ f' fuzzy logic type ={type(self.logic).__name__}\n'
1328
+ f' non_fluents exact ={self.compile_non_fluent_exact}\n'
1329
+ f' cpfs_no_gradient ={self.cpfs_without_grad}\n'
1330
+ f'optimizer hyper-parameters:\n'
1331
+ f' use_64_bit ={self.use64bit}\n'
1332
+ f' optimizer ={self.optimizer_name}\n'
1333
+ f' optimizer args ={self.optimizer_kwargs}\n'
1334
+ f' clip_gradient ={self.clip_grad}\n'
1335
+ f' line_search_kwargs={self.line_search_kwargs}\n'
1336
+ f' noise_kwargs ={self.noise_kwargs}\n'
1337
+ f' batch_size_train ={self.batch_size_train}\n'
1338
+ f' batch_size_test ={self.batch_size_test}')
1339
+ result += '\n' + str(self.plan)
1340
+ result += '\n' + str(self.logic)
1341
+
1342
+ # print model relaxation information
1343
+ if not self.compiled.model_params:
1344
+ return result
1345
+ result += '\n' + ('Some RDDL operations are non-differentiable '
1346
+ 'and will be approximated as follows:' + '\n')
1347
+ exprs_by_rddl_op, values_by_rddl_op = {}, {}
1348
+ for info in self.compiled.model_parameter_info().values():
1349
+ rddl_op = info['rddl_op']
1350
+ exprs_by_rddl_op.setdefault(rddl_op, []).append(info['id'])
1351
+ values_by_rddl_op.setdefault(rddl_op, []).append(info['init_value'])
1352
+ for rddl_op in sorted(exprs_by_rddl_op.keys()):
1353
+ result += (f' {rddl_op}:\n'
1354
+ f' addresses ={exprs_by_rddl_op[rddl_op]}\n'
1355
+ f' init_values={values_by_rddl_op[rddl_op]}\n')
1356
+ return result
1359
1357
 
1360
1358
  def summarize_hyperparameters(self) -> None:
1361
- print(f'objective hyper-parameters:\n'
1362
- f' utility_fn ={self.utility.__name__}\n'
1363
- f' utility args ={self.utility_kwargs}\n'
1364
- f' use_symlog ={self.use_symlog_reward}\n'
1365
- f' lookahead ={self.horizon}\n'
1366
- f' user_action_bounds={self._action_bounds}\n'
1367
- f' fuzzy logic type ={type(self.logic).__name__}\n'
1368
- f' nonfluents exact ={self.compile_non_fluent_exact}\n'
1369
- f' cpfs_no_gradient ={self.cpfs_without_grad}\n'
1370
- f'optimizer hyper-parameters:\n'
1371
- f' use_64_bit ={self.use64bit}\n'
1372
- f' optimizer ={self._optimizer_name.__name__}\n'
1373
- f' optimizer args ={self._optimizer_kwargs}\n'
1374
- f' clip_gradient ={self.clip_grad}\n'
1375
- f' batch_size_train ={self.batch_size_train}\n'
1376
- f' batch_size_test ={self.batch_size_test}')
1377
- self.plan.summarize_hyperparameters()
1378
- self.logic.summarize_hyperparameters()
1359
+ print(self.__str__())
1379
1360
 
1380
1361
  # ===========================================================================
1381
1362
  # COMPILATION SUBROUTINES
@@ -1391,14 +1372,16 @@ class JaxBackpropPlanner:
1391
1372
  logger=self.logger,
1392
1373
  use64bit=self.use64bit,
1393
1374
  cpfs_without_grad=self.cpfs_without_grad,
1394
- compile_non_fluent_exact=self.compile_non_fluent_exact)
1375
+ compile_non_fluent_exact=self.compile_non_fluent_exact
1376
+ )
1395
1377
  self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')
1396
1378
 
1397
1379
  # Jax compilation of the exact RDDL for testing
1398
1380
  self.test_compiled = JaxRDDLCompiler(
1399
- rddl=rddl,
1381
+ rddl=rddl,
1400
1382
  logger=self.logger,
1401
- use64bit=self.use64bit)
1383
+ use64bit=self.use64bit
1384
+ )
1402
1385
  self.test_compiled.compile(log_jax_expr=True, heading='EXACT MODEL')
1403
1386
 
1404
1387
  def _jax_compile_optimizer(self):
@@ -1414,13 +1397,15 @@ class JaxBackpropPlanner:
1414
1397
  train_rollouts = self.compiled.compile_rollouts(
1415
1398
  policy=self.plan.train_policy,
1416
1399
  n_steps=self.horizon,
1417
- n_batch=self.batch_size_train)
1400
+ n_batch=self.batch_size_train
1401
+ )
1418
1402
  self.train_rollouts = train_rollouts
1419
1403
 
1420
1404
  test_rollouts = self.test_compiled.compile_rollouts(
1421
1405
  policy=self.plan.test_policy,
1422
1406
  n_steps=self.horizon,
1423
- n_batch=self.batch_size_test)
1407
+ n_batch=self.batch_size_test
1408
+ )
1424
1409
  self.test_rollouts = jax.jit(test_rollouts)
1425
1410
 
1426
1411
  # initialization
@@ -1455,14 +1440,16 @@ class JaxBackpropPlanner:
1455
1440
  _jax_wrapped_returns = self._jax_return(use_symlog)
1456
1441
 
1457
1442
  # the loss is the average cumulative reward across all roll-outs
1458
- def _jax_wrapped_plan_loss(key, policy_params, hyperparams,
1443
+ def _jax_wrapped_plan_loss(key, policy_params, policy_hyperparams,
1459
1444
  subs, model_params):
1460
- log = rollouts(key, policy_params, hyperparams, subs, model_params)
1445
+ log, model_params = rollouts(
1446
+ key, policy_params, policy_hyperparams, subs, model_params)
1461
1447
  rewards = log['reward']
1462
1448
  returns = _jax_wrapped_returns(rewards)
1463
1449
  utility = utility_fn(returns, **utility_kwargs)
1464
1450
  loss = -utility
1465
- return loss, log
1451
+ aux = (log, model_params)
1452
+ return loss, aux
1466
1453
 
1467
1454
  return _jax_wrapped_plan_loss
1468
1455
 
@@ -1470,30 +1457,45 @@ class JaxBackpropPlanner:
1470
1457
  init = self.plan.initializer
1471
1458
  optimizer = self.optimizer
1472
1459
 
1473
- def _jax_wrapped_init_policy(key, hyperparams, subs):
1474
- policy_params = init(key, hyperparams, subs)
1460
+ # initialize both the policy and its optimizer
1461
+ def _jax_wrapped_init_policy(key, policy_hyperparams, subs):
1462
+ policy_params = init(key, policy_hyperparams, subs)
1475
1463
  opt_state = optimizer.init(policy_params)
1476
- return policy_params, opt_state, None
1464
+ return policy_params, opt_state, {}
1477
1465
 
1478
1466
  return _jax_wrapped_init_policy
1479
1467
 
1480
1468
  def _jax_update(self, loss):
1481
1469
  optimizer = self.optimizer
1482
1470
  projection = self.plan.projection
1471
+ use_ls = self.line_search_kwargs is not None
1483
1472
 
1484
1473
  # calculate the plan gradient w.r.t. return loss and update optimizer
1485
1474
  # also perform a projection step to satisfy constraints on actions
1486
- def _jax_wrapped_plan_update(key, policy_params, hyperparams,
1475
+ def _jax_wrapped_loss_swapped(policy_params, key, policy_hyperparams,
1476
+ subs, model_params):
1477
+ return loss(key, policy_params, policy_hyperparams, subs, model_params)[0]
1478
+
1479
+ def _jax_wrapped_plan_update(key, policy_params, policy_hyperparams,
1487
1480
  subs, model_params, opt_state, opt_aux):
1488
1481
  grad_fn = jax.value_and_grad(loss, argnums=1, has_aux=True)
1489
- (loss_val, log), grad = grad_fn(
1490
- key, policy_params, hyperparams, subs, model_params)
1491
- updates, opt_state = optimizer.update(grad, opt_state)
1482
+ (loss_val, (log, model_params)), grad = grad_fn(
1483
+ key, policy_params, policy_hyperparams, subs, model_params)
1484
+ if use_ls:
1485
+ updates, opt_state = optimizer.update(
1486
+ grad, opt_state, params=policy_params,
1487
+ value=loss_val, grad=grad, value_fn=_jax_wrapped_loss_swapped,
1488
+ key=key, policy_hyperparams=policy_hyperparams, subs=subs,
1489
+ model_params=model_params)
1490
+ else:
1491
+ updates, opt_state = optimizer.update(
1492
+ grad, opt_state, params=policy_params)
1492
1493
  policy_params = optax.apply_updates(policy_params, updates)
1493
- policy_params, converged = projection(policy_params, hyperparams)
1494
+ policy_params, converged = projection(policy_params, policy_hyperparams)
1494
1495
  log['grad'] = grad
1495
1496
  log['updates'] = updates
1496
- return policy_params, converged, opt_state, None, loss_val, log
1497
+ return policy_params, converged, opt_state, opt_aux, \
1498
+ loss_val, log, model_params
1497
1499
 
1498
1500
  return jax.jit(_jax_wrapped_plan_update)
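When a line search is configured, optimizer.update additionally receives the current loss value, the gradient, and a value_fn so the search can re-evaluate the objective along the step direction. A self-contained sketch of that contract on a toy quadratic, assuming a recent optax version with extra-argument support in chain:

# Toy objective optimized with SGD plus a zoom line search; the extra keyword
# arguments are forwarded to the line-search transformation.
import jax
import jax.numpy as jnp
import optax

def f(w):
    return jnp.sum((w - 3.0) ** 2)

opt = optax.chain(optax.sgd(learning_rate=1.0),
                  optax.scale_by_zoom_linesearch(max_linesearch_steps=10))
w = jnp.zeros(4)
state = opt.init(w)
value, grad = jax.value_and_grad(f)(w)
updates, state = opt.update(grad, state, w, value=value, grad=grad, value_fn=f)
w = optax.apply_updates(w, updates)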
1499
1501
 
@@ -1520,11 +1522,10 @@ class JaxBackpropPlanner:
1520
1522
  for (state, next_state) in rddl.next_state.items():
1521
1523
  init_train[next_state] = init_train[state]
1522
1524
  init_test[next_state] = init_test[state]
1523
-
1524
1525
  return init_train, init_test
1525
1526
 
1526
1527
  def as_optimization_problem(
1527
- self, key: Optional[random.PRNGKey]=None,
1528
+ self, key: Optional[random.PRNGKey]=None,
1528
1529
  policy_hyperparams: Optional[Pytree]=None,
1529
1530
  loss_function_updates_key: bool=True,
1530
1531
  grad_function_updates_key: bool=False) -> Tuple[Callable, Callable, np.ndarray, Callable]:
@@ -1574,38 +1575,40 @@ class JaxBackpropPlanner:
1574
1575
  loss_fn = self._jax_loss(self.train_rollouts)
1575
1576
 
1576
1577
  @jax.jit
1577
- def _loss_with_key(key, params_1d):
1578
+ def _loss_with_key(key, params_1d, model_params):
1578
1579
  policy_params = unravel_fn(params_1d)
1579
- loss_val, _ = loss_fn(key, policy_params, policy_hyperparams,
1580
- train_subs, model_params)
1581
- return loss_val
1580
+ loss_val, (_, model_params) = loss_fn(
1581
+ key, policy_params, policy_hyperparams, train_subs, model_params)
1582
+ return loss_val, model_params
1582
1583
 
1583
1584
  @jax.jit
1584
- def _grad_with_key(key, params_1d):
1585
+ def _grad_with_key(key, params_1d, model_params):
1585
1586
  policy_params = unravel_fn(params_1d)
1586
1587
  grad_fn = jax.grad(loss_fn, argnums=1, has_aux=True)
1587
- grad_val, _ = grad_fn(key, policy_params, policy_hyperparams,
1588
- train_subs, model_params)
1589
- grad_1d = jax.flatten_util.ravel_pytree(grad_val)[0]
1590
- return grad_1d
1588
+ grad_val, (_, model_params) = grad_fn(
1589
+ key, policy_params, policy_hyperparams, train_subs, model_params)
1590
+ grad_val = jax.flatten_util.ravel_pytree(grad_val)[0]
1591
+ return grad_val, model_params
1591
1592
 
1592
1593
  def _loss_function(params_1d):
1593
1594
  nonlocal key
1595
+ nonlocal model_params
1594
1596
  if loss_function_updates_key:
1595
1597
  key, subkey = random.split(key)
1596
1598
  else:
1597
1599
  subkey = key
1598
- loss_val = _loss_with_key(subkey, params_1d)
1600
+ loss_val, model_params = _loss_with_key(subkey, params_1d, model_params)
1599
1601
  loss_val = float(loss_val)
1600
1602
  return loss_val
1601
1603
 
1602
1604
  def _grad_function(params_1d):
1603
1605
  nonlocal key
1606
+ nonlocal model_params
1604
1607
  if grad_function_updates_key:
1605
1608
  key, subkey = random.split(key)
1606
1609
  else:
1607
1610
  subkey = key
1608
- grad = _grad_with_key(subkey, params_1d)
1611
+ grad, model_params = _grad_with_key(subkey, params_1d, model_params)
1609
1612
  grad = np.asarray(grad)
1610
1613
  return grad
1611
1614
 
@@ -1620,9 +1623,9 @@ class JaxBackpropPlanner:
1620
1623
 
1621
1624
  :param key: JAX PRNG key (derived from clock if not provided)
1622
1625
  :param epochs: the maximum number of steps of gradient descent
1623
- :param train_seconds: total time allocated for gradient descent
1624
- :param plot_step: frequency to plot the plan and save result to disk
1625
- :param plot_kwargs: additional arguments to pass to the plotter
1626
+ :param train_seconds: total time allocated for gradient descent
1627
+ :param dashboard: dashboard to display training results
1628
+ :param dashboard_id: experiment id for the dashboard
1626
1629
  :param model_params: optional model-parameters to override default
1627
1630
  :param policy_hyperparams: hyper-parameters for the policy/plan, such as
1628
1631
  weights for sigmoid wrapping boolean actions
@@ -1633,6 +1636,7 @@ class JaxBackpropPlanner:
1633
1636
  :param print_summary: whether to print planner header, parameter
1634
1637
  summary, and diagnosis
1635
1638
  :param print_progress: whether to print the progress bar during training
1639
+ :param stopping_rule: stopping criterion
1636
1640
  :param test_rolling_window: the test return is averaged on a rolling
1637
1641
  window of the past test_rolling_window returns when updating the best
1638
1642
  parameters found so far
@@ -1657,14 +1661,15 @@ class JaxBackpropPlanner:
1657
1661
  def optimize_generator(self, key: Optional[random.PRNGKey]=None,
1658
1662
  epochs: int=999999,
1659
1663
  train_seconds: float=120.,
1660
- plot_step: Optional[int]=None,
1661
- plot_kwargs: Optional[Dict[str, Any]]=None,
1664
+ dashboard: Optional[JaxPlannerDashboard]=None,
1665
+ dashboard_id: Optional[str]=None,
1662
1666
  model_params: Optional[Dict[str, Any]]=None,
1663
1667
  policy_hyperparams: Optional[Dict[str, Any]]=None,
1664
1668
  subs: Optional[Dict[str, Any]]=None,
1665
1669
  guess: Optional[Pytree]=None,
1666
1670
  print_summary: bool=True,
1667
1671
  print_progress: bool=True,
1672
+ stopping_rule: Optional[JaxPlannerStoppingRule]=None,
1668
1673
  test_rolling_window: int=10,
1669
1674
  tqdm_position: Optional[int]=None) -> Generator[Dict[str, Any], None, None]:
1670
1675
  '''Returns a generator for computing an optimal policy or plan.
@@ -1673,9 +1678,9 @@ class JaxBackpropPlanner:
1673
1678
 
1674
1679
  :param key: JAX PRNG key (derived from clock if not provided)
1675
1680
  :param epochs: the maximum number of steps of gradient descent
1676
- :param train_seconds: total time allocated for gradient descent
1677
- :param plot_step: frequency to plot the plan and save result to disk
1678
- :param plot_kwargs: additional arguments to pass to the plotter
1681
+ :param train_seconds: total time allocated for gradient descent
1682
+ :param dashboard: dashboard to display training results
1683
+ :param dashboard_id: experiment id for the dashboard
1679
1684
  :param model_params: optional model-parameters to override default
1680
1685
  :param policy_hyperparams: hyper-parameters for the policy/plan, such as
1681
1686
  weights for sigmoid wrapping boolean actions
@@ -1686,6 +1691,7 @@ class JaxBackpropPlanner:
1686
1691
  :param print_summary: whether to print planner header, parameter
1687
1692
  summary, and diagnosis
1688
1693
  :param print_progress: whether to print the progress bar during training
1694
+ :param stopping_rule: stopping criterion
1689
1695
  :param test_rolling_window: the test return is averaged on a rolling
1690
1696
  window of the past test_rolling_window returns when updating the best
1691
1697
  parameters found so far
@@ -1694,9 +1700,14 @@ class JaxBackpropPlanner:
1694
1700
  start_time = time.time()
1695
1701
  elapsed_outside_loop = 0
1696
1702
 
1703
+ # ======================================================================
1704
+ # INITIALIZATION OF HYPER-PARAMETERS
1705
+ # ======================================================================
1706
+
1697
1707
  # if PRNG key is not provided
1698
1708
  if key is None:
1699
1709
  key = random.PRNGKey(round(time.time() * 1000))
1710
+ dash_key = key[1].item()
1700
1711
 
1701
1712
  # if policy_hyperparams is not provided
1702
1713
  if policy_hyperparams is None:
@@ -1723,7 +1734,7 @@ class JaxBackpropPlanner:
1723
1734
 
1724
1735
  # print summary of parameters:
1725
1736
  if print_summary:
1726
- self._summarize_system()
1737
+ print(self.summarize_system())
1727
1738
  self.summarize_hyperparameters()
1728
1739
  print(f'optimize() call hyper-parameters:\n'
1729
1740
  f' PRNG key ={key}\n'
@@ -1734,15 +1745,16 @@ class JaxBackpropPlanner:
1734
1745
  f' override_subs_dict ={subs is not None}\n'
1735
1746
  f' provide_param_guess={guess is not None}\n'
1736
1747
  f' test_rolling_window={test_rolling_window}\n'
1737
- f' plot_frequency ={plot_step}\n'
1738
- f' plot_kwargs ={plot_kwargs}\n'
1748
+ f' dashboard ={dashboard is not None}\n'
1749
+ f' dashboard_id ={dashboard_id}\n'
1739
1750
  f' print_summary ={print_summary}\n'
1740
- f' print_progress ={print_progress}\n')
1741
- if self.compiled.relaxations:
1742
- print('Some RDDL operations are non-differentiable, '
1743
- 'replacing them with differentiable relaxations:')
1744
- print(self.compiled.summarize_model_relaxations())
1745
-
1751
+ f' print_progress ={print_progress}\n'
1752
+ f' stopping_rule ={stopping_rule}\n')
1753
+
1754
+ # ======================================================================
1755
+ # INITIALIZATION OF STATE AND POLICY
1756
+ # ======================================================================
1757
+
1746
1758
  # compute a batched version of the initial values
1747
1759
  if subs is None:
1748
1760
  subs = self.test_compiled.init_values
@@ -1773,7 +1785,11 @@ class JaxBackpropPlanner:
1773
1785
  else:
1774
1786
  policy_params = guess
1775
1787
  opt_state = self.optimizer.init(policy_params)
1776
- opt_aux = None
1788
+ opt_aux = {}
1789
+
1790
+ # ======================================================================
1791
+ # INITIALIZATION OF RUNNING STATISTICS
1792
+ # ======================================================================
1777
1793
 
1778
1794
  # initialize running statistics
1779
1795
  best_params, best_loss, best_grad = policy_params, jnp.inf, jnp.inf
@@ -1783,30 +1799,43 @@ class JaxBackpropPlanner:
1783
1799
  status = JaxPlannerStatus.NORMAL
1784
1800
  is_all_zero_fn = lambda x: np.allclose(x, 0)
1785
1801
 
1786
- # initialize plot area
1787
- if plot_step is None or plot_step <= 0 or plt is None:
1788
- plot = None
1789
- else:
1790
- if plot_kwargs is None:
1791
- plot_kwargs = {}
1792
- plot = JaxPlannerPlot(self.rddl, self.horizon, **plot_kwargs)
1793
- xticks, loss_values = [], []
1802
+ # initialize stopping criterion
1803
+ if stopping_rule is not None:
1804
+ stopping_rule.reset()
1805
+
1806
+ # initialize dash board
1807
+ if dashboard is not None:
1808
+ dashboard_id = dashboard.register_experiment(
1809
+ dashboard_id, dashboard.get_planner_info(self),
1810
+ key=dash_key, viz=self.dashboard_viz)
1811
+
1812
+ # ======================================================================
1813
+ # MAIN TRAINING LOOP BEGINS
1814
+ # ======================================================================
1794
1815
 
1795
- # training loop
1796
1816
  iters = range(epochs)
1797
1817
  if print_progress:
1798
1818
  iters = tqdm(iters, total=100, position=tqdm_position)
1799
1819
  position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
1800
1820
 
1801
1821
  for it in iters:
1822
+
1823
+ # ==================================================================
1824
+ # NEXT GRADIENT DESCENT STEP
1825
+ # ==================================================================
1826
+
1802
1827
  status = JaxPlannerStatus.NORMAL
1803
1828
 
1804
1829
  # update the parameters of the plan
1805
1830
  key, subkey = random.split(key)
1806
- policy_params, converged, opt_state, opt_aux, \
1807
- train_loss, train_log = \
1831
+ (policy_params, converged, opt_state, opt_aux,
1832
+ train_loss, train_log, model_params) = \
1808
1833
  self.update(subkey, policy_params, policy_hyperparams,
1809
1834
  train_subs, model_params, opt_state, opt_aux)
1835
+
1836
+ # ==================================================================
1837
+ # STATUS CHECKS AND LOGGING
1838
+ # ==================================================================
1810
1839
 
1811
1840
  # no progress
1812
1841
  grad_norm_zero, _ = jax.tree_util.tree_flatten(
@@ -1825,12 +1854,11 @@ class JaxBackpropPlanner:
1825
1854
  # numerical error
1826
1855
  if not np.isfinite(train_loss):
1827
1856
  raise_warning(
1828
- f'Aborting JAX planner due to invalid train loss {train_loss}.',
1829
- 'red')
1857
+ f'JAX planner aborted due to invalid loss {train_loss}.', 'red')
1830
1858
  status = JaxPlannerStatus.INVALID_GRADIENT
1831
1859
 
1832
1860
  # evaluate test losses and record best plan so far
1833
- test_loss, log = self.test_loss(
1861
+ test_loss, (log, model_params_test) = self.test_loss(
1834
1862
  subkey, policy_params, policy_hyperparams,
1835
1863
  test_subs, model_params_test)
1836
1864
  test_loss = rolling_test_loss.update(test_loss)
@@ -1839,34 +1867,16 @@ class JaxBackpropPlanner:
1839
1867
  policy_params, test_loss, train_log['grad']
1840
1868
  last_iter_improve = it
1841
1869
 
1842
- # save the plan figure
1843
- if plot is not None and it % plot_step == 0:
1844
- xticks.append(it // plot_step)
1845
- loss_values.append(test_loss.item())
1846
- action_values = {name: values
1847
- for (name, values) in log['fluents'].items()
1848
- if name in self.rddl.action_fluents}
1849
- returns = -np.sum(np.asarray(log['reward']), axis=1)
1850
- plot.redraw(xticks, loss_values, action_values, returns)
1851
-
1852
- # if the progress bar is used
1853
- elapsed = time.time() - start_time - elapsed_outside_loop
1854
- if print_progress:
1855
- iters.n = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
1856
- iters.set_description(
1857
- f'{position_str} {it:6} it / {-train_loss:14.6f} train / '
1858
- f'{-test_loss:14.6f} test / {-best_loss:14.6f} best / '
1859
- f'{status.value} status')
1860
-
1861
1870
  # reached computation budget
1871
+ elapsed = time.time() - start_time - elapsed_outside_loop
1862
1872
  if elapsed >= train_seconds:
1863
1873
  status = JaxPlannerStatus.TIME_BUDGET_REACHED
1864
1874
  if it >= epochs - 1:
1865
1875
  status = JaxPlannerStatus.ITER_BUDGET_REACHED
1866
1876
 
1867
- # return a callback
1868
- start_time_outside = time.time()
1869
- yield {
1877
+ # build a callback
1878
+ progress_percent = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
1879
+ callback = {
1870
1880
  'status': status,
1871
1881
  'iteration': it,
1872
1882
  'train_return':-train_loss,
@@ -1880,19 +1890,45 @@ class JaxBackpropPlanner:
                'updates': train_log['updates'],
                'elapsed_time': elapsed,
                'key': key,
+               'model_params': model_params,
+               'progress': progress_percent,
+               'train_log': train_log,
                **log
            }
+
+           # stopping condition reached
+           if stopping_rule is not None and stopping_rule.monitor(callback):
+               callback['status'] = status = JaxPlannerStatus.STOPPING_RULE_REACHED
+
+           # if the progress bar is used
+           if print_progress:
+               iters.n = progress_percent
+               iters.set_description(
+                   f'{position_str} {it:6} it / {-train_loss:14.6f} train / '
+                   f'{-test_loss:14.6f} test / {-best_loss:14.6f} best / '
+                   f'{status.value} status'
+               )
+
+           # dash-board
+           if dashboard is not None:
+               dashboard.update_experiment(dashboard_id, callback)
+
+           # yield the callback
+           start_time_outside = time.time()
+           yield callback
            elapsed_outside_loop += (time.time() - start_time_outside)
 
            # abortion check
-           if status.is_failure():
-               break
-
+           if status.is_terminal():
+               break
+
+       # ======================================================================
+       # POST-PROCESSING AND CLEANUP
+       # ======================================================================
+
        # release resources
        if print_progress:
            iters.close()
-       if plot is not None:
-           plot.close()
 
        # validate the test return
        if log:
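Note: the rewritten loop above now packages each training iteration into a callback dictionary, optionally forwards it to a stopping rule and a dashboard, and then yields it. Below is a hedged sketch of how such a yielding optimizer might be consumed with a custom stopping rule. Only the monitor(callback) contract and the callback keys ('status', 'iteration', 'train_return', 'progress') are taken from the hunks above; the generator's public name (optimize_generator), the constructor arguments, and the domain/instance identifiers are illustrative assumptions.

    import jax.random as random
    import pyRDDLGym
    from pyRDDLGym_jax.core.planner import JaxBackpropPlanner

    class NoImprovementRule:
        '''Hypothetical stopping rule: request a halt when the train return stalls.'''

        def __init__(self, patience: int=200) -> None:
            self.patience = patience
            self.best = -float('inf')
            self.stale = 0

        def monitor(self, callback) -> bool:
            # returning True asks the planner to set STOPPING_RULE_REACHED
            if callback['train_return'] > self.best:
                self.best, self.stale = callback['train_return'], 0
            else:
                self.stale += 1
            return self.stale >= self.patience

    env = pyRDDLGym.make('HVAC', '1', vectorized=True)   # placeholder ids
    planner = JaxBackpropPlanner(rddl=env.model)
    for callback in planner.optimize_generator(key=random.PRNGKey(42),
                                               stopping_rule=NoImprovementRule()):
        print(callback['iteration'], callback['progress'], callback['status'])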
@@ -1918,7 +1954,7 @@ class JaxBackpropPlanner:
            f' best_grad_norm={grad_norm}\n'
            f' diagnosis: {diagnosis}\n')
 
-   def _perform_diagnosis(self, last_iter_improve,
+   def _perform_diagnosis(self, last_iter_improve,
                           train_return, test_return, best_return, grad_norm):
        max_grad_norm = max(jax.tree_util.tree_leaves(grad_norm))
        grad_is_zero = np.allclose(max_grad_norm, 0)
@@ -2016,101 +2052,6 @@ class JaxBackpropPlanner:
        return actions
 
 
-class JaxLineSearchPlanner(JaxBackpropPlanner):
-    '''A class for optimizing an action sequence in the given RDDL MDP using
-    linear search gradient descent, with the Armijo condition.'''
-
-    def __init__(self, *args,
-                 decay: float=0.8,
-                 c: float=0.1,
-                 step_max: float=1.0,
-                 step_min: float=1e-6,
-                 **kwargs) -> None:
-        '''Creates a new gradient-based algorithm for optimizing action sequences
-        (plan) in the given RDDL using line search. All arguments are the
-        same as in the parent class, except:
-
-        :param decay: reduction factor of learning rate per line search iteration
-        :param c: positive coefficient in Armijo condition, should be in (0, 1)
-        :param step_max: initial learning rate for line search
-        :param step_min: minimum possible learning rate (line search halts)
-        '''
-        self.decay = decay
-        self.c = c
-        self.step_max = step_max
-        self.step_min = step_min
-        if 'clip_grad' in kwargs:
-            raise_warning('clip_grad parameter conflicts with '
-                          'line search planner and will be ignored.', 'red')
-            del kwargs['clip_grad']
-        super(JaxLineSearchPlanner, self).__init__(*args, **kwargs)
-
-    def summarize_hyperparameters(self) -> None:
-        super(JaxLineSearchPlanner, self).summarize_hyperparameters()
-        print(f'linesearch hyper-parameters:\n'
-              f' decay ={self.decay}\n'
-              f' c ={self.c}\n'
-              f' lr_range=({self.step_min}, {self.step_max})')
-
-    def _jax_update(self, loss):
-        optimizer = self.optimizer
-        projection = self.plan.projection
-        decay, c, lrmax, lrmin = self.decay, self.c, self.step_max, self.step_min
-
-        # initialize the line search routine
-        @jax.jit
-        def _jax_wrapped_line_search_init(key, policy_params, hyperparams,
-                                          subs, model_params):
-            (f, log), grad = jax.value_and_grad(loss, argnums=1, has_aux=True)(
-                key, policy_params, hyperparams, subs, model_params)
-            gnorm2 = jax.tree_map(lambda x: jnp.sum(jnp.square(x)), grad)
-            gnorm2 = jax.tree_util.tree_reduce(jnp.add, gnorm2)
-            log['grad'] = grad
-            return f, grad, gnorm2, log
-
-        # compute the next trial solution
-        @jax.jit
-        def _jax_wrapped_line_search_trial(
-                step, grad, key, params, hparams, subs, mparams, state):
-            state.hyperparams['learning_rate'] = step
-            updates, new_state = optimizer.update(grad, state)
-            new_params = optax.apply_updates(params, updates)
-            new_params, _ = projection(new_params, hparams)
-            f_step, _ = loss(key, new_params, hparams, subs, mparams)
-            return f_step, new_params, new_state
-
-        # main iteration of line search
-        def _jax_wrapped_plan_update(key, policy_params, hyperparams,
-                                     subs, model_params, opt_state, opt_aux):
-
-            # initialize the line search
-            f, grad, gnorm2, log = _jax_wrapped_line_search_init(
-                key, policy_params, hyperparams, subs, model_params)
-
-            # continue to reduce the learning rate until the Armijo condition holds
-            trials = 0
-            step = lrmax / decay
-            f_step = np.inf
-            best_f, best_step, best_params, best_state = np.inf, None, None, None
-            while (f_step > f - c * step * gnorm2 and step * decay >= lrmin) \
-                    or not trials:
-                trials += 1
-                step *= decay
-                f_step, new_params, new_state = _jax_wrapped_line_search_trial(
-                    step, grad, key, policy_params, hyperparams, subs,
-                    model_params, opt_state)
-                if f_step < best_f:
-                    best_f, best_step, best_params, best_state = \
-                        f_step, step, new_params, new_state
-
-            log['updates'] = None
-            log['line_search_iters'] = trials
-            log['learning_rate'] = best_step
-            return best_params, True, best_state, best_step, best_f, log
-
-        return _jax_wrapped_plan_update
-
-
 # ***********************************************************************
 # ALL VERSIONS OF RISK FUNCTIONS
 #
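Note: the hunk above removes JaxLineSearchPlanner entirely in 1.0. For reference, the accept test in its deleted loop is the standard Armijo sufficient-decrease condition: a trial step a is accepted once f(x - a*grad) <= f(x) - c*a*||grad||^2, otherwise the step is multiplied by decay. A compact restatement of that backtracking rule, detached from the planner internals and not part of the 1.0 API:

    def backtracking_step(f0, gnorm2, loss_at, step_max=1.0, step_min=1e-6,
                          decay=0.8, c=0.1):
        '''Shrink the step size until the Armijo condition holds (sketch).

        f0      : loss at the current parameters
        gnorm2  : squared norm of the gradient at the current parameters
        loss_at : callable mapping a step size to the loss at the trial point
        '''
        step = step_max
        while step >= step_min:
            if loss_at(step) <= f0 - c * step * gnorm2:   # sufficient decrease holds
                return step
            step *= decay                                  # backtrack
        return step_min                                    # give up at the floor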
@@ -2141,7 +2082,7 @@ def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
     alpha_mask = jax.lax.stop_gradient(
         returns <= jnp.percentile(returns, q=100 * alpha))
     return jnp.sum(returns * alpha_mask) / jnp.sum(alpha_mask)
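Note: the context lines above define the CVaR utility as the mean of the returns that fall at or below the empirical alpha-quantile, with the mask kept out of the gradient via stop_gradient. A quick numerical check of that masked-mean formula on illustrative values:

    import jax.numpy as jnp

    returns = jnp.arange(1.0, 11.0)                    # returns 1, 2, ..., 10
    alpha = 0.3
    cutoff = jnp.percentile(returns, q=100 * alpha)    # 30th percentile = 3.7
    mask = returns <= cutoff                           # keeps 1, 2 and 3
    cvar = jnp.sum(returns * mask) / jnp.sum(mask)     # (1 + 2 + 3) / 3 = 2.0
    print(float(cutoff), float(cvar))                  # 3.7 2.0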
-
+
 
 # ***********************************************************************
 # ALL VERSIONS OF CONTROLLERS
@@ -2151,12 +2092,13 @@ def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
 #
 # ***********************************************************************
 
+
 class JaxOfflineController(BaseAgent):
     '''A container class for a Jax policy trained offline.'''
 
     use_tensor_obs = True
 
-    def __init__(self, planner: JaxBackpropPlanner,
+    def __init__(self, planner: JaxBackpropPlanner,
                  key: Optional[random.PRNGKey]=None,
                  eval_hyperparams: Optional[Dict[str, Any]]=None,
                  params: Optional[Pytree]=None,
@@ -2186,8 +2128,10 @@ class JaxOfflineController(BaseAgent):
         self.params_given = params is not None
 
         self.step = 0
+        self.callback = None
         if not self.train_on_reset and not self.params_given:
             callback = self.planner.optimize(key=self.key, **self.train_kwargs)
+            self.callback = callback
             params = callback['best_params']
         self.params = params
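Note: JaxOfflineController now keeps the final optimizer callback on self.callback, so training diagnostics remain available after construction. A hedged usage sketch; the domain and instance identifiers, planner settings, and the assumption that extra keyword arguments are forwarded to optimize are illustrative rather than taken from this diff:

    import jax.random as random
    import pyRDDLGym
    from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOfflineController

    env = pyRDDLGym.make('Wildfire_MDP_ippc2014', '1', vectorized=True)  # placeholder ids
    planner = JaxBackpropPlanner(rddl=env.model)
    controller = JaxOfflineController(planner, key=random.PRNGKey(42))

    # inspect the stored training callback (keys visible in the hunks above)
    print(controller.callback['status'], controller.callback['iteration'])

    # roll out the trained policy through the standard pyRDDLGym agent interface
    controller.evaluate(env, episodes=1, verbose=True)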
 
@@ -2202,6 +2146,7 @@ class JaxOfflineController(BaseAgent):
         self.step = 0
         if self.train_on_reset and not self.params_given:
             callback = self.planner.optimize(key=self.key, **self.train_kwargs)
+            self.callback = callback
             self.params = callback['best_params']
 
 
@@ -2211,7 +2156,7 @@ class JaxOnlineController(BaseAgent):
 
     use_tensor_obs = True
 
-    def __init__(self, planner: JaxBackpropPlanner,
+    def __init__(self, planner: JaxBackpropPlanner,
                  key: Optional[random.PRNGKey]=None,
                  eval_hyperparams: Optional[Dict[str, Any]]=None,
                  warm_start: bool=True,
@@ -2244,7 +2189,9 @@ class JaxOnlineController(BaseAgent):
             key=self.key,
             guess=self.guess,
             subs=state,
-            **self.train_kwargs)
+            **self.train_kwargs
+        )
+        self.callback = callback
         params = callback['best_params']
         self.key, subkey = random.split(self.key)
         actions = planner.get_action(
@@ -2255,4 +2202,5 @@ class JaxOnlineController(BaseAgent):
 
     def reset(self) -> None:
         self.guess = None
+        self.callback = None
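Note: JaxOnlineController re-plans at every decision step, warm-starting each optimization from the previous solution (self.guess), and as of this version also exposes the most recent optimizer callback as self.callback (cleared on reset). A hedged usage sketch along the same lines as the offline controller above, again with placeholder identifiers and illustrative keyword arguments:

    import jax.random as random
    import pyRDDLGym
    from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOnlineController

    env = pyRDDLGym.make('Reservoir_Continuous', '1', vectorized=True)  # placeholder ids
    planner = JaxBackpropPlanner(rddl=env.model)
    agent = JaxOnlineController(planner, key=random.PRNGKey(42))

    # each sampled action triggers a fresh, warm-started optimization
    agent.evaluate(env, episodes=1, verbose=True)
    print(agent.callback['status'])   # callback from the last replanning step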