pyRDDLGym-jax 2.8__py3-none-any.whl → 3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. pyRDDLGym_jax/__init__.py +1 -1
  2. pyRDDLGym_jax/core/compiler.py +1080 -906
  3. pyRDDLGym_jax/core/logic.py +1537 -1369
  4. pyRDDLGym_jax/core/model.py +75 -86
  5. pyRDDLGym_jax/core/planner.py +883 -935
  6. pyRDDLGym_jax/core/simulator.py +20 -17
  7. pyRDDLGym_jax/core/tuning.py +11 -7
  8. pyRDDLGym_jax/core/visualization.py +115 -78
  9. pyRDDLGym_jax/entry_point.py +2 -1
  10. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +6 -8
  11. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +5 -7
  12. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +7 -8
  13. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +7 -8
  14. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +8 -9
  15. pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +5 -7
  16. pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +5 -7
  17. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +7 -8
  18. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +6 -7
  19. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +6 -7
  20. pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +6 -8
  21. pyRDDLGym_jax/examples/configs/Quadcopter_physics_drp.cfg +17 -0
  22. pyRDDLGym_jax/examples/configs/Quadcopter_physics_slp.cfg +17 -0
  23. pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +5 -7
  24. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +4 -7
  25. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +5 -7
  26. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +4 -7
  27. pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +5 -7
  28. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +6 -7
  29. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +6 -7
  30. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +6 -7
  31. pyRDDLGym_jax/examples/configs/default_drp.cfg +5 -8
  32. pyRDDLGym_jax/examples/configs/default_replan.cfg +5 -8
  33. pyRDDLGym_jax/examples/configs/default_slp.cfg +5 -8
  34. pyRDDLGym_jax/examples/configs/tuning_drp.cfg +6 -8
  35. pyRDDLGym_jax/examples/configs/tuning_replan.cfg +6 -8
  36. pyRDDLGym_jax/examples/configs/tuning_slp.cfg +6 -8
  37. pyRDDLGym_jax/examples/run_plan.py +2 -2
  38. pyRDDLGym_jax/examples/run_tune.py +2 -2
  39. {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/METADATA +22 -23
  40. pyrddlgym_jax-3.0.dist-info/RECORD +51 -0
  41. {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/WHEEL +1 -1
  42. pyRDDLGym_jax/examples/run_gradient.py +0 -102
  43. pyrddlgym_jax-2.8.dist-info/RECORD +0 -50
  44. {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/entry_points.txt +0 -0
  45. {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/licenses/LICENSE +0 -0
  46. {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/top_level.txt +0 -0
@@ -43,8 +43,7 @@ import pickle
43
43
  import sys
44
44
  import time
45
45
  import traceback
46
- from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Type, Tuple, \
47
- Union
46
+ from typing import Any, Callable, Dict, Generator, Optional, Sequence, Type, Tuple, Union
48
47
 
49
48
  import haiku as hk
50
49
  import jax
@@ -70,16 +69,18 @@ from pyRDDLGym.core.debug.exception import (
70
69
  from pyRDDLGym.core.policy import BaseAgent
71
70
 
72
71
  from pyRDDLGym_jax import __version__
73
- from pyRDDLGym_jax.core import logic
74
72
  from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler
75
- from pyRDDLGym_jax.core.logic import Logic, FuzzyLogic
73
+ from pyRDDLGym_jax.core import logic
74
+ from pyRDDLGym_jax.core.logic import (
75
+ JaxRDDLCompilerWithGrad, DefaultJaxRDDLCompilerWithGrad, stable_sigmoid
76
+ )
76
77
 
77
78
  # try to load the dash board
78
79
  try:
79
80
  from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard
80
81
  except Exception:
81
82
  raise_warning('Failed to load the dashboard visualization tool: '
82
- 'please make sure you have installed the required packages.', 'red')
83
+ 'ensure all prerequisite packages are installed.', 'red')
83
84
  traceback.print_exc()
84
85
  JaxPlannerDashboard = None
85
86
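For downstream code, the practical consequence of the import changes above is that the gradient-relaxation compiler classes are now exported from pyRDDLGym_jax.core.logic rather than defined in the planner module (the old in-planner class is deleted further down in this diff). A minimal usage sketch, assuming pyRDDLGym-jax 3.0 is installed:

    from pyRDDLGym_jax.core.logic import (
        JaxRDDLCompilerWithGrad, DefaultJaxRDDLCompilerWithGrad, stable_sigmoid
    )
    # in 2.8, JaxRDDLCompilerWithGrad was defined in pyRDDLGym_jax.core.planner;
    # that definition is removed in 3.0, so imports must use the new path above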
 
@@ -128,33 +129,15 @@ def _getattr_any(packages, item):
128
129
 
129
130
 
130
131
  def _load_config(config, args):
131
- model_args = {k: args['Model'][k] for (k, _) in config.items('Model')}
132
- planner_args = {k: args['Optimizer'][k] for (k, _) in config.items('Optimizer')}
133
- train_args = {k: args['Training'][k] for (k, _) in config.items('Training')}
134
-
135
- # read the model settings
136
- logic_name = model_args.get('logic', 'FuzzyLogic')
137
- logic_kwargs = model_args.get('logic_kwargs', {})
138
- if logic_name == 'FuzzyLogic':
139
- tnorm_name = model_args.get('tnorm', 'ProductTNorm')
140
- tnorm_kwargs = model_args.get('tnorm_kwargs', {})
141
- comp_name = model_args.get('complement', 'StandardComplement')
142
- comp_kwargs = model_args.get('complement_kwargs', {})
143
- compare_name = model_args.get('comparison', 'SigmoidComparison')
144
- compare_kwargs = model_args.get('comparison_kwargs', {})
145
- sampling_name = model_args.get('sampling', 'SoftRandomSampling')
146
- sampling_kwargs = model_args.get('sampling_kwargs', {})
147
- rounding_name = model_args.get('rounding', 'SoftRounding')
148
- rounding_kwargs = model_args.get('rounding_kwargs', {})
149
- control_name = model_args.get('control', 'SoftControlFlow')
150
- control_kwargs = model_args.get('control_kwargs', {})
151
- logic_kwargs['tnorm'] = getattr(logic, tnorm_name)(**tnorm_kwargs)
152
- logic_kwargs['complement'] = getattr(logic, comp_name)(**comp_kwargs)
153
- logic_kwargs['comparison'] = getattr(logic, compare_name)(**compare_kwargs)
154
- logic_kwargs['sampling'] = getattr(logic, sampling_name)(**sampling_kwargs)
155
- logic_kwargs['rounding'] = getattr(logic, rounding_name)(**rounding_kwargs)
156
- logic_kwargs['control'] = getattr(logic, control_name)(**control_kwargs)
132
+ compiler_kwargs = {k: args['Compiler'][k] for (k, _) in config.items('Compiler')}
133
+ planner_args = {k: args['Planner'][k] for (k, _) in config.items('Planner')}
134
+ train_args = {k: args['Optimize'][k] for (k, _) in config.items('Optimize')}
157
135
 
136
+ # read the compiler settings
137
+ compiler_name = compiler_kwargs.pop('method', 'DefaultJaxRDDLCompilerWithGrad')
138
+ planner_args['compiler'] = getattr(logic, compiler_name)
139
+ planner_args['compiler_kwargs'] = compiler_kwargs
140
+
158
141
  # read the policy settings
159
142
  plan_method = planner_args.pop('method')
160
143
  plan_kwargs = planner_args.pop('method_kwargs', {})
@@ -183,7 +166,6 @@ def _load_config(config, args):
183
166
  plan_kwargs['activation'] = activation
184
167
 
185
168
  # read the planner settings
186
- planner_args['logic'] = getattr(logic, logic_name)(**logic_kwargs)
187
169
  planner_args['plan'] = getattr(sys.modules[__name__], plan_method)(**plan_kwargs)
188
170
 
189
171
  # planner optimizer
@@ -220,11 +202,11 @@ def _load_config(config, args):
220
202
  train_args['key'] = random.PRNGKey(planner_key)
221
203
 
222
204
  # dashboard
223
- dashboard_key = train_args.get('dashboard', None)
205
+ dashboard_key = planner_args.get('dashboard', None)
224
206
  if dashboard_key is not None and dashboard_key and JaxPlannerDashboard is not None:
225
- train_args['dashboard'] = JaxPlannerDashboard()
207
+ planner_args['dashboard'] = JaxPlannerDashboard()
226
208
  elif dashboard_key is not None:
227
- del train_args['dashboard']
209
+ del planner_args['dashboard']
228
210
 
229
211
  # optimize call stopping rule
230
212
  stopping_rule = train_args.get('stopping_rule', None)
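The net effect of the hunks above is a new configuration layout: the 2.8 sections [Model]/[Optimizer]/[Training] become [Compiler]/[Planner]/[Optimize] in 3.0, and instead of assembling tnorm/complement/comparison/sampling/rounding/control objects into a FuzzyLogic, _load_config now just resolves a compiler class by name from core.logic. A minimal sketch of that resolution step (the dictionary literal below stands in for a parsed [Compiler] section and is not a real config file):

    from pyRDDLGym_jax.core import logic

    compiler_kwargs = {'method': 'DefaultJaxRDDLCompilerWithGrad'}   # parsed [Compiler] section
    compiler_name = compiler_kwargs.pop('method', 'DefaultJaxRDDLCompilerWithGrad')
    planner_args = {
        'compiler': getattr(logic, compiler_name),   # class object, instantiated by the planner
        'compiler_kwargs': compiler_kwargs,          # remaining keys forwarded as keyword arguments
    }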
@@ -253,102 +235,6 @@ def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
253
235
  config, args = _parse_config_string(value)
254
236
  return _load_config(config, args)
255
237
 
256
-
257
- # ***********************************************************************
258
- # MODEL RELAXATIONS
259
- #
260
- # - replace discrete ops in state dynamics/reward with differentiable ones
261
- #
262
- # ***********************************************************************
263
-
264
-
265
- class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
266
- '''Compiles a RDDL AST representation to an equivalent JAX representation.
267
- Unlike its parent class, this class treats all fluents as real-valued, and
268
- replaces all mathematical operations by equivalent ones with a well defined
269
- (e.g. non-zero) gradient where appropriate.
270
- '''
271
-
272
- def __init__(self, *args,
273
- logic: Logic=FuzzyLogic(),
274
- cpfs_without_grad: Optional[Set[str]]=None,
275
- print_warnings: bool=True,
276
- **kwargs) -> None:
277
- '''Creates a new RDDL to Jax compiler, where operations that are not
278
- differentiable are converted to approximate forms that have defined gradients.
279
-
280
- :param *args: arguments to pass to base compiler
281
- :param logic: Fuzzy logic object that specifies how exact operations
282
- are converted to their approximate forms: this class may be subclassed
283
- to customize these operations
284
- :param cpfs_without_grad: which CPFs do not have gradients (use straight
285
- through gradient trick)
286
- :param print_warnings: whether to print warnings
287
- :param *kwargs: keyword arguments to pass to base compiler
288
- '''
289
- super(JaxRDDLCompilerWithGrad, self).__init__(*args, **kwargs)
290
-
291
- self.logic = logic
292
- self.logic.set_use64bit(self.use64bit)
293
- if cpfs_without_grad is None:
294
- cpfs_without_grad = set()
295
- self.cpfs_without_grad = cpfs_without_grad
296
- self.print_warnings = print_warnings
297
-
298
- # actions and CPFs must be continuous
299
- pvars_cast = set()
300
- for (var, values) in self.init_values.items():
301
- self.init_values[var] = np.asarray(values, dtype=self.REAL)
302
- if not np.issubdtype(np.result_type(values), np.floating):
303
- pvars_cast.add(var)
304
- if self.print_warnings and pvars_cast:
305
- message = termcolor.colored(
306
- f'[INFO] JAX gradient compiler will cast p-vars {pvars_cast} to float.',
307
- 'green')
308
- print(message)
309
-
310
- # overwrite basic operations with fuzzy ones
311
- self.OPS = logic.get_operator_dicts()
312
-
313
- def _jax_stop_grad(self, jax_expr):
314
- def _jax_wrapped_stop_grad(x, params, key):
315
- sample, key, error, params = jax_expr(x, params, key)
316
- sample = jax.lax.stop_gradient(sample)
317
- return sample, key, error, params
318
- return _jax_wrapped_stop_grad
319
-
320
- def _compile_cpfs(self, init_params):
321
-
322
- # cpfs will all be cast to float
323
- cpfs_cast = set()
324
- jax_cpfs = {}
325
- for (_, cpfs) in self.levels.items():
326
- for cpf in cpfs:
327
- _, expr = self.rddl.cpfs[cpf]
328
- jax_cpfs[cpf] = self._jax(expr, init_params, dtype=self.REAL)
329
- if self.rddl.variable_ranges[cpf] != 'real':
330
- cpfs_cast.add(cpf)
331
- if cpf in self.cpfs_without_grad:
332
- jax_cpfs[cpf] = self._jax_stop_grad(jax_cpfs[cpf])
333
-
334
- if self.print_warnings and cpfs_cast:
335
- message = termcolor.colored(
336
- f'[INFO] JAX gradient compiler will cast CPFs {cpfs_cast} to float.',
337
- 'green')
338
- print(message)
339
- if self.print_warnings and self.cpfs_without_grad:
340
- message = termcolor.colored(
341
- f'[INFO] Gradients will not flow through CPFs {self.cpfs_without_grad}.',
342
- 'green')
343
- print(message)
344
-
345
- return jax_cpfs
346
-
347
- def _jax_kron(self, expr, init_params):
348
- arg, = expr.args
349
- arg = self._jax(arg, init_params)
350
- return arg
351
-
352
238
 
353
239
  # ***********************************************************************
354
240
  # ALL VERSIONS OF STATE PREPROCESSING FOR DRP
@@ -369,15 +255,15 @@ class Preprocessor(metaclass=ABCMeta):
369
255
  self._transform = None
370
256
 
371
257
  @property
372
- def initialize(self):
258
+ def initialize(self) -> Callable:
373
259
  return self._initializer
374
260
 
375
261
  @property
376
- def update(self):
262
+ def update(self) -> Callable:
377
263
  return self._update
378
264
 
379
265
  @property
380
- def transform(self):
266
+ def transform(self) -> Callable:
381
267
  return self._transform
382
268
 
383
269
  @abstractmethod
@@ -421,26 +307,25 @@ class StaticNormalizer(Preprocessor):
421
307
  self._initializer = jax.jit(_jax_wrapped_normalizer_init)
422
308
 
423
309
  # static bounds
424
- def _jax_wrapped_normalizer_update(subs, stats):
425
- stats = {var: (jnp.asarray(lower, dtype=compiled.REAL),
426
- jnp.asarray(upper, dtype=compiled.REAL))
427
- for (var, (lower, upper)) in bounded_vars.items()}
428
- return stats
310
+ def _jax_wrapped_normalizer_update(fls, stats):
311
+ return {var: (jnp.asarray(lower, dtype=compiled.REAL),
312
+ jnp.asarray(upper, dtype=compiled.REAL))
313
+ for (var, (lower, upper)) in bounded_vars.items()}
429
314
  self._update = jax.jit(_jax_wrapped_normalizer_update)
430
315
 
431
316
  # apply min max scaling
432
- def _jax_wrapped_normalizer_transform(subs, stats):
433
- new_subs = {}
434
- for (var, values) in subs.items():
317
+ def _jax_wrapped_normalizer_transform(fls, stats):
318
+ new_fls = {}
319
+ for (var, values) in fls.items():
435
320
  if var in stats:
436
321
  lower, upper = stats[var]
437
322
  new_dims = jnp.ndim(values) - jnp.ndim(lower)
438
323
  lower = lower[(jnp.newaxis,) * new_dims + (...,)]
439
324
  upper = upper[(jnp.newaxis,) * new_dims + (...,)]
440
- new_subs[var] = (values - lower) / (upper - lower)
325
+ new_fls[var] = (values - lower) / (upper - lower)
441
326
  else:
442
- new_subs[var] = values
443
- return new_subs
327
+ new_fls[var] = values
328
+ return new_fls
444
329
  self._transform = jax.jit(_jax_wrapped_normalizer_transform)
445
330
 
446
331
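The rewritten transform above is ordinary min-max scaling applied per fluent, with fluents that have no recorded bounds passed through unchanged. A toy illustration (the fluent names and bounds are invented; the real per-variable statistics come from bounded_vars):

    import jax.numpy as jnp

    stats = {'level': (jnp.array(0.0), jnp.array(100.0))}     # per-fluent (lower, upper)
    fls = {'level': jnp.array([25.0, 50.0, 75.0]),
           'flag': jnp.array([1.0])}                          # no bounds recorded for 'flag'

    scaled = {var: ((values - stats[var][0]) / (stats[var][1] - stats[var][0])
                    if var in stats else values)
              for (var, values) in fls.items()}
    # scaled['level'] -> [0.25, 0.5, 0.75]; scaled['flag'] passes through unchanged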
 
@@ -478,40 +363,38 @@ class JaxPlan(metaclass=ABCMeta):
478
363
  pass
479
364
 
480
365
  @property
481
- def initializer(self):
366
+ def initializer(self) -> Callable:
482
367
  return self._initializer
483
368
 
484
369
  @initializer.setter
485
- def initializer(self, value):
370
+ def initializer(self, value: Callable) -> None:
486
371
  self._initializer = value
487
372
 
488
373
  @property
489
- def train_policy(self):
374
+ def train_policy(self) -> Callable:
490
375
  return self._train_policy
491
376
 
492
377
  @train_policy.setter
493
- def train_policy(self, value):
378
+ def train_policy(self, value: Callable) -> None:
494
379
  self._train_policy = value
495
380
 
496
381
  @property
497
- def test_policy(self):
382
+ def test_policy(self) -> Callable:
498
383
  return self._test_policy
499
384
 
500
385
  @test_policy.setter
501
- def test_policy(self, value):
386
+ def test_policy(self, value: Callable) -> None:
502
387
  self._test_policy = value
503
388
 
504
389
  @property
505
- def projection(self):
390
+ def projection(self) -> Callable:
506
391
  return self._projection
507
392
 
508
393
  @projection.setter
509
- def projection(self, value):
394
+ def projection(self, value: Callable) -> None:
510
395
  self._projection = value
511
396
 
512
- def _calculate_action_info(self, compiled: JaxRDDLCompilerWithGrad,
513
- user_bounds: Bounds,
514
- horizon: int):
397
+ def _calculate_action_info(self, compiled, user_bounds, horizon):
515
398
  shapes, bounds, bounds_safe, cond_lists = {}, {}, {}, {}
516
399
  for (name, prange) in compiled.rddl.variable_ranges.items():
517
400
  if compiled.rddl.variable_types[name] != 'action-fluent':
@@ -522,7 +405,8 @@ class JaxPlan(metaclass=ABCMeta):
522
405
  keys = list(compiled.JAX_TYPES.keys()) + list(compiled.rddl.enum_types)
523
406
  raise RDDLTypeError(
524
407
  f'Invalid range <{prange}> of action-fluent <{name}>, '
525
- f'must be one of {keys}.')
408
+ f'must be one of {keys}.'
409
+ )
526
410
 
527
411
  # clip boolean to (0, 1), otherwise use the RDDL action bounds
528
412
  # or the user defined action bounds if provided
@@ -530,15 +414,22 @@ class JaxPlan(metaclass=ABCMeta):
530
414
  if prange == 'bool':
531
415
  lower, upper = None, None
532
416
  else:
417
+
418
+ # enum values are ordered from 0 to number of objects - 1
533
419
  if prange in compiled.rddl.enum_types:
534
420
  lower = np.zeros(shape=shapes[name][1:])
535
421
  upper = len(compiled.rddl.type_to_objects[prange]) - 1
536
422
  upper = np.ones(shape=shapes[name][1:]) * upper
537
423
  else:
538
424
  lower, upper = compiled.constraints.bounds[name]
425
+
426
+ # override with user defined bounds
539
427
  lower, upper = user_bounds.get(name, (lower, upper))
540
428
  lower = np.asarray(lower, dtype=compiled.REAL)
541
429
  upper = np.asarray(upper, dtype=compiled.REAL)
430
+
431
+ # get masks for a jax conditional statement to avoid numerical errors
432
+ # for infinite values
542
433
  lower_finite = np.isfinite(lower)
543
434
  upper_finite = np.isfinite(upper)
544
435
  bounds_safe[name] = (np.where(lower_finite, lower, 0.0),
@@ -548,21 +439,173 @@ class JaxPlan(metaclass=ABCMeta):
548
439
  ~lower_finite & upper_finite,
549
440
  ~lower_finite & ~upper_finite]
550
441
  bounds[name] = (lower, upper)
442
+
551
443
  if compiled.print_warnings:
552
- message = termcolor.colored(
444
+ print(termcolor.colored(
553
445
  f'[INFO] Bounds of action-fluent <{name}> set to {bounds[name]}.',
554
- 'green')
555
- print(message)
446
+ 'dark_grey'
447
+ ))
556
448
  return shapes, bounds, bounds_safe, cond_lists
557
449
 
558
- def _count_bool_actions(self, rddl: RDDLLiftedModel):
450
+ def _count_bool_actions(self, rddl):
559
451
  constraint = rddl.max_allowed_actions
560
452
  num_bool_actions = sum(np.size(values)
561
453
  for (var, values) in rddl.action_fluents.items()
562
454
  if rddl.variable_ranges[var] == 'bool')
563
455
  return num_bool_actions, constraint
564
456
 
457
+
458
+ class JaxActionProjection(metaclass=ABCMeta):
459
+ '''Base of all straight-line plan action projections.'''
460
+
461
+ @abstractmethod
462
+ def compile(self, *args, **kwargs) -> Callable:
463
+ pass
565
464
 
465
+
466
+ class JaxSortingActionProjection(JaxActionProjection):
467
+ '''Action projection using sorting method.'''
468
+
469
+ def compile(self, ranges: Dict[str, str], noop: Dict[str, Any],
470
+ wrap_sigmoid: bool, allowed_actions: int, bool_threshold: float,
471
+ jax_bool_to_box: Callable, *args, **kwargs) -> Callable:
472
+
473
+ # shift the boolean actions uniformly, clipping at the min/max values
474
+ # the amount to move is such that only top allowed_actions actions
475
+ # are still active (e.g. not equal to noop) after the shift
476
+ def _jax_wrapped_sorting_project(params, hyperparams):
477
+
478
+ # find the amount to shift action parameters: if noop=True reflect parameter
479
+ scores = []
480
+ for (var, param) in params.items():
481
+ if ranges[var] == 'bool':
482
+ param_flat = jnp.ravel(param, order='C')
483
+ if noop[var]:
484
+ if wrap_sigmoid:
485
+ param_flat = -param_flat
486
+ else:
487
+ param_flat = 1.0 - param_flat
488
+ scores.append(param_flat)
489
+ scores = jnp.concatenate(scores)
490
+ descending = jnp.sort(scores)[::-1]
491
+ kplus1st_greatest = descending[allowed_actions]
492
+ surplus = jnp.maximum(kplus1st_greatest - bool_threshold, 0.0)
493
+
494
+ # perform the shift
495
+ new_params = {}
496
+ for (var, param) in params.items():
497
+ if ranges[var] == 'bool':
498
+ if noop[var]:
499
+ new_param = param + surplus
500
+ else:
501
+ new_param = param - surplus
502
+ new_params[var] = jax_bool_to_box(var, new_param, hyperparams)
503
+ else:
504
+ new_params[var] = param
505
+ converged = jnp.array(True, dtype=jnp.bool_)
506
+ return new_params, converged
507
+ return _jax_wrapped_sorting_project
508
+
509
+
510
+ class JaxSogbofaActionProjection(JaxActionProjection):
511
+ '''Action projection using the SOGBOFA method.'''
512
+
513
+ def compile(self, ranges: Dict[str, str], noop: Dict[str, Any],
514
+ allowed_actions: int, max_constraint_iter: int,
515
+ jax_param_to_action: Callable, jax_action_to_param: Callable,
516
+ min_action: float, max_action: float, real_dtype: type, *args, **kwargs) -> Callable:
517
+
518
+ # calculate the surplus of actions above max-nondef-actions
519
+ def _jax_wrapped_sogbofa_surplus(actions):
520
+ sum_action = jnp.array(0.0, dtype=real_dtype)
521
+ k = jnp.array(0, dtype=jnp.int32)
522
+ for (var, action) in actions.items():
523
+ if ranges[var] == 'bool':
524
+ if noop[var]:
525
+ action = 1 - action
526
+ sum_action = sum_action + jnp.sum(action)
527
+ k = k + jnp.count_nonzero(action)
528
+ surplus = jnp.maximum(sum_action - allowed_actions, 0.0)
529
+ return surplus, k
530
+
531
+ # return whether the surplus is positive or reached compute limit
532
+ def _jax_wrapped_sogbofa_continue(values):
533
+ it, _, surplus, k = values
534
+ return jnp.logical_and(
535
+ it < max_constraint_iter, jnp.logical_and(surplus > 0, k > 0))
536
+
537
+ # reduce all bool action values by the surplus clipping at minimum
538
+ # for no-op = True, do the opposite, i.e. increase all
539
+ # bool action values by surplus clipping at maximum
540
+ def _jax_wrapped_sogbofa_subtract_surplus(values):
541
+ it, actions, surplus, k = values
542
+ amount = surplus / k
543
+ new_actions = {}
544
+ for (var, action) in actions.items():
545
+ if ranges[var] == 'bool':
546
+ if noop[var]:
547
+ new_actions[var] = jnp.minimum(action + amount, 1)
548
+ else:
549
+ new_actions[var] = jnp.maximum(action - amount, 0)
550
+ else:
551
+ new_actions[var] = action
552
+ new_surplus, new_k = _jax_wrapped_sogbofa_surplus(new_actions)
553
+ new_it = it + 1
554
+ return new_it, new_actions, new_surplus, new_k
555
+
556
+ # apply the surplus to the actions until it becomes zero
557
+ def _jax_wrapped_sogbofa_project(params, hyperparams):
558
+
559
+ # convert parameters to actions
560
+ actions = {}
561
+ for (var, param) in params.items():
562
+ if ranges[var] == 'bool':
563
+ actions[var] = jax_param_to_action(var, param, hyperparams)
564
+ else:
565
+ actions[var] = param
566
+
567
+ # run SOGBOFA loop on the actions to get adjusted actions
568
+ surplus, k = _jax_wrapped_sogbofa_surplus(actions)
569
+ _, actions, surplus, k = jax.lax.while_loop(
570
+ cond_fun=_jax_wrapped_sogbofa_continue,
571
+ body_fun=_jax_wrapped_sogbofa_subtract_surplus,
572
+ init_val=(0, actions, surplus, k)
573
+ )
574
+ converged = jnp.logical_not(surplus > 0)
575
+
576
+ # check for any remaining constraint violation
577
+ total_bool = jnp.array(0, dtype=jnp.int32)
578
+ for (var, action) in actions.items():
579
+ if ranges[var] == 'bool':
580
+ if noop[var]:
581
+ total_bool = total_bool + jnp.count_nonzero(action < 0.5)
582
+ else:
583
+ total_bool = total_bool + jnp.count_nonzero(action > 0.5)
584
+ excess = jnp.maximum(total_bool - allowed_actions, 0)
585
+
586
+ # convert the adjusted actions back to parameters
587
+ # reduce the excess number of parameters that are non-noop above constraint
588
+ new_params = {}
589
+ for (var, action) in actions.items():
590
+ if ranges[var] == 'bool':
591
+ action = jnp.clip(action, min_action, max_action)
592
+ flat_action = jnp.ravel(action, order='C')
593
+ if noop[var]:
594
+ ranks = jnp.cumsum(flat_action < 0.5)
595
+ replace_mask = (flat_action < 0.5) & (ranks <= excess)
596
+ else:
597
+ ranks = jnp.cumsum(flat_action > 0.5)
598
+ replace_mask = (flat_action > 0.5) & (ranks <= excess)
599
+ flat_action = jnp.where(replace_mask, 0.5, flat_action)
600
+ action = jnp.reshape(flat_action, jnp.shape(action))
601
+ new_params[var] = jax_action_to_param(var, action, hyperparams)
602
+ excess = jnp.maximum(excess - jnp.count_nonzero(replace_mask), 0)
603
+ else:
604
+ new_params[var] = action
605
+ return new_params, converged
606
+ return _jax_wrapped_sogbofa_project
607
+
608
+
566
609
  class JaxStraightLinePlan(JaxPlan):
567
610
  '''A straight line plan implementation in JAX'''
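The JaxSortingActionProjection added above enforces max-nondef-actions by shifting every boolean action score by a common surplus, so that only the top allowed_actions remain above the activity threshold; the result is then clipped back into the valid box via jax_bool_to_box. A toy numeric sketch of the shift on a plain array (the real code works on flattened parameter pytrees):

    import jax.numpy as jnp

    allowed_actions, threshold = 2, 0.5
    scores = jnp.array([0.9, 0.8, 0.7, 0.1])               # activity of four boolean actions
    kplus1st = jnp.sort(scores)[::-1][allowed_actions]      # (k+1)-th largest score = 0.7
    surplus = jnp.maximum(kplus1st - threshold, 0.0)        # common shift = 0.2
    shifted = scores - surplus                              # [0.7, 0.6, 0.5, -0.1]
    # only the top-2 entries stay strictly above the 0.5 threshold after the shift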
568
611
 
@@ -607,7 +650,7 @@ class JaxStraightLinePlan(JaxPlan):
607
650
  def __str__(self) -> str:
608
651
  bounds = '\n '.join(
609
652
  map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
610
- return (f'policy hyper-parameters:\n'
653
+ return (f'[INFO] policy hyper-parameters:\n'
611
654
  f' initializer={self._initializer_base}\n'
612
655
  f' constraint-sat strategy (simple):\n'
613
656
  f' parsed_action_bounds =\n {bounds}\n'
@@ -630,16 +673,7 @@ class JaxStraightLinePlan(JaxPlan):
630
673
  compiled, _bounds, horizon)
631
674
  self.bounds = bounds
632
675
 
633
- # action concurrency check
634
- bool_action_count, allowed_actions = self._count_bool_actions(rddl)
635
- use_constraint_satisfaction = allowed_actions < bool_action_count
636
- if compiled.print_warnings and use_constraint_satisfaction:
637
- message = termcolor.colored(
638
- f'[INFO] SLP will use projected gradient to satisfy '
639
- f'max_nondef_actions since total boolean actions '
640
- f'{bool_action_count} > max_nondef_actions {allowed_actions}.', 'green')
641
- print(message)
642
-
676
+ # get the noop action values
643
677
  noop = {var: (values[0] if isinstance(values, list) else values)
644
678
  for (var, values) in rddl.action_fluents.items()}
645
679
  bool_key = 'bool__'
@@ -649,14 +683,18 @@ class JaxStraightLinePlan(JaxPlan):
649
683
  #
650
684
  # ***********************************************************************
651
685
 
652
- # define the mapping between trainable parameter and action
686
+ # boolean actions are parameters wrapped by sigmoid to ensure [0, 1]:
687
+ #
688
+ # action = sigmoid(weight * param)
689
+ #
690
+ # here weight is a hyper-parameter and param is the trainable policy parameter
653
691
  wrap_sigmoid = self._wrap_sigmoid
654
692
  bool_threshold = 0.0 if wrap_sigmoid else 0.5
655
693
 
656
694
  def _jax_bool_param_to_action(var, param, hyperparams):
657
695
  if wrap_sigmoid:
658
696
  weight = hyperparams[var]
659
- return jax.nn.sigmoid(weight * param)
697
+ return stable_sigmoid(weight * param)
660
698
  else:
661
699
  return param
662
700
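A small numeric sketch of the parameter/action mapping described in the comment above; jax.nn.sigmoid stands in here for the package's stable_sigmoid, which is assumed to compute the same function in a numerically safer way:

    import jax
    import jax.numpy as jnp

    weight = 10.0                          # sharpness hyper-parameter for one boolean action
    param = jnp.array([-0.3, 0.0, 0.3])    # trainable parameters for three decision steps
    action = jax.nn.sigmoid(weight * param)                 # ~[0.047, 0.5, 0.953], always in (0, 1)
    recovered = jax.scipy.special.logit(action) / weight    # inverse map: ~[-0.3, 0.0, 0.3]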
 
@@ -666,7 +704,10 @@ class JaxStraightLinePlan(JaxPlan):
666
704
  return jax.scipy.special.logit(action) / weight
667
705
  else:
668
706
  return action
669
-
707
+
708
+ # the same technique could be applied to non-bool actions following Bueno et al.
709
+ # this is disabled by default since the gradient projection trick seems to work
710
+ # better, especially for one-sided bounds (-inf, B) or (B, +inf)
670
711
  wrap_non_bool = self._wrap_non_bool
671
712
 
672
713
  def _jax_non_bool_param_to_action(var, param, hyperparams):
@@ -675,45 +716,24 @@ class JaxStraightLinePlan(JaxPlan):
675
716
  mb, ml, mu, mn = [jnp.asarray(mask, dtype=compiled.REAL)
676
717
  for mask in cond_lists[var]]
677
718
  action = (
678
- mb * (lower + (upper - lower) * jax.nn.sigmoid(param)) +
679
- ml * (lower + (jax.nn.elu(param) + 1.0)) +
680
- mu * (upper - (jax.nn.elu(-param) + 1.0)) +
719
+ mb * (lower + (upper - lower) * stable_sigmoid(param)) +
720
+ ml * (lower + jax.nn.softplus(param)) +
721
+ mu * (upper - jax.nn.softplus(-param)) +
681
722
  mn * param
682
723
  )
683
724
  else:
684
725
  action = param
685
726
  return action
686
-
687
- # handle box constraints
688
- min_action = self._min_action_prob
689
- max_action = 1.0 - min_action
690
-
691
- def _jax_project_bool_to_box(var, param, hyperparams):
692
- lower = _jax_bool_action_to_param(var, min_action, hyperparams)
693
- upper = _jax_bool_action_to_param(var, max_action, hyperparams)
694
- valid_param = jnp.clip(param, lower, upper)
695
- return valid_param
696
-
727
+
728
+ # a different option to handle boolean action concurrency constraints with |A| = 1
729
+ # is to use a softmax activation layer over pooled action parameters
697
730
  ranges = rddl.variable_ranges
698
-
699
- def _jax_wrapped_slp_project_to_box(params, hyperparams):
700
- new_params = {}
701
- for (var, param) in params.items():
702
- if var == bool_key:
703
- new_params[var] = param
704
- elif ranges[var] == 'bool':
705
- new_params[var] = _jax_project_bool_to_box(var, param, hyperparams)
706
- elif wrap_non_bool:
707
- new_params[var] = param
708
- else:
709
- new_params[var] = jnp.clip(param, *bounds[var])
710
- return new_params, True
711
-
712
- # convert softmax action back to action dict
713
731
  action_sizes = {var: np.prod(shape[1:], dtype=np.int64)
714
732
  for (var, shape) in shapes.items()
715
733
  if ranges[var] == 'bool'}
716
734
 
735
+ # given a softmax output, this simply unpacks the result of the softmax back into
736
+ # the original action fluent dictionary
717
737
  def _jax_unstack_bool_from_softmax(output):
718
738
  actions = {}
719
739
  start = 0
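The softplus-based transforms introduced above squash unbounded parameters into the action bounds; which branch applies is selected by the finiteness masks computed in _calculate_action_info. A small sketch of the three cases with made-up bounds (jax.nn.sigmoid again stands in for stable_sigmoid):

    import jax
    import jax.numpy as jnp

    param = jnp.array(0.7)
    lower, upper = 2.0, 5.0

    a_box = lower + (upper - lower) * jax.nn.sigmoid(param)   # both bounds finite: in (2, 5)
    a_low = lower + jax.nn.softplus(param)                    # only lower bound finite: > 2
    a_high = upper - jax.nn.softplus(-param)                  # only upper bound finite: < 5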
@@ -723,11 +743,12 @@ class JaxStraightLinePlan(JaxPlan):
723
743
  if noop[name]:
724
744
  action = 1.0 - action
725
745
  actions[name] = action
726
- start += size
746
+ start = start + size
727
747
  return actions
728
748
 
729
- # train plan prediction (TODO: implement one-hot for integer actions)
730
- def _jax_wrapped_slp_predict_train(key, params, hyperparams, step, subs):
749
+ # the main subroutine to compute the trainable rddl actions from the trainable
750
+ # parameters (TODO: implement one-hot for integer actions)
751
+ def _jax_wrapped_slp_predict_train(key, params, hyperparams, step, fls):
731
752
  actions = {}
732
753
  for (var, param) in params.items():
733
754
  action = jnp.asarray(param[step, ...], dtype=compiled.REAL)
@@ -740,9 +761,12 @@ class JaxStraightLinePlan(JaxPlan):
740
761
  else:
741
762
  actions[var] = _jax_non_bool_param_to_action(var, action, hyperparams)
742
763
  return actions
764
+ self.train_policy = _jax_wrapped_slp_predict_train
743
765
 
744
- # test plan prediction
745
- def _jax_wrapped_slp_predict_test(key, params, hyperparams, step, subs):
766
+ # the main subroutine to compute the test rddl actions from the trainable
767
+ # parameters: the difference here is that actions are converted to their required
768
+ # types (i.e. bool, int, float)
769
+ def _jax_wrapped_slp_predict_test(key, params, hyperparams, step, fls):
746
770
  actions = {}
747
771
  for (var, param) in params.items():
748
772
  action = jnp.asarray(param[step, ...], dtype=compiled.REAL)
@@ -760,8 +784,6 @@ class JaxStraightLinePlan(JaxPlan):
760
784
  action = jnp.asarray(jnp.round(action), dtype=compiled.INT)
761
785
  actions[var] = action
762
786
  return actions
763
-
764
- self.train_policy = _jax_wrapped_slp_predict_train
765
787
  self.test_policy = _jax_wrapped_slp_predict_test
766
788
 
767
789
  # ***********************************************************************
@@ -769,148 +791,76 @@ class JaxStraightLinePlan(JaxPlan):
769
791
  #
770
792
  # ***********************************************************************
771
793
 
772
- # use a softmax output activation
794
+ # optionally clip boolean action parameters to user-specified min/max probabilities:
795
+ # this helps avoid saturation of action-fluents, where the
796
+ # gradient could vanish as a result
797
+ min_action = self._min_action_prob
798
+ max_action = 1.0 - min_action
799
+
800
+ def _jax_project_bool_to_box(var, param, hyperparams):
801
+ lower = _jax_bool_action_to_param(var, min_action, hyperparams)
802
+ upper = _jax_bool_action_to_param(var, max_action, hyperparams)
803
+ return jnp.clip(param, lower, upper)
804
+
805
+ def _jax_wrapped_slp_project_to_box(params, hyperparams):
806
+ new_params = {}
807
+ for (var, param) in params.items():
808
+ if var == bool_key:
809
+ new_params[var] = param
810
+ elif ranges[var] == 'bool':
811
+ new_params[var] = _jax_project_bool_to_box(var, param, hyperparams)
812
+ elif wrap_non_bool:
813
+ new_params[var] = param
814
+ else:
815
+ new_params[var] = jnp.clip(param, *bounds[var])
816
+ converged = jnp.array(True, dtype=jnp.bool_)
817
+ return new_params, converged
818
+
819
+ # enable constraint satisfaction subroutines during optimization
820
+ # if there are nontrivial concurrency constraints in the problem description
821
+ bool_action_count, allowed_actions = self._count_bool_actions(rddl)
822
+ use_constraint_satisfaction = allowed_actions < bool_action_count
823
+ if compiled.print_warnings and use_constraint_satisfaction:
824
+ print(termcolor.colored(
825
+ f'[INFO] Number of boolean actions {bool_action_count} '
826
+ f'> max_nondef_actions {allowed_actions}: enabling projected gradient to '
827
+ f'satisfy constraints on action-fluents.', 'dark_grey'
828
+ ))
829
+
830
+ # use a softmax output activation: only allow one action non-noop for now
773
831
  if use_constraint_satisfaction and self._wrap_softmax:
774
-
775
- # only allow one action non-noop for now
776
832
  if 1 < allowed_actions < bool_action_count:
777
833
  raise RDDLNotImplementedError(
778
834
  f'SLPs with wrap_softmax currently '
779
- f'do not support max-nondef-actions {allowed_actions} > 1.')
780
-
781
- # potentially apply projection but to non-bool actions only
835
+ f'do not support max-nondef-actions {allowed_actions} > 1.'
836
+ )
782
837
  self.projection = _jax_wrapped_slp_project_to_box
783
838
 
784
- # use new gradient projection method...
839
+ # use new gradient projection method
785
840
  elif use_constraint_satisfaction and self._use_new_projection:
786
-
787
- # shift the boolean actions uniformly, clipping at the min/max values
788
- # the amount to move is such that only top allowed_actions actions
789
- # are still active (e.g. not equal to noop) after the shift
790
- def _jax_wrapped_sorting_project(params, hyperparams):
791
-
792
- # find the amount to shift action parameters
793
- # if noop is True pretend it is False and reflect the parameter
794
- scores = []
795
- for (var, param) in params.items():
796
- if ranges[var] == 'bool':
797
- param_flat = jnp.ravel(param, order='C')
798
- if noop[var]:
799
- if wrap_sigmoid:
800
- param_flat = -param_flat
801
- else:
802
- param_flat = 1.0 - param_flat
803
- scores.append(param_flat)
804
- scores = jnp.concatenate(scores)
805
- descending = jnp.sort(scores)[::-1]
806
- kplus1st_greatest = descending[allowed_actions]
807
- surplus = jnp.maximum(kplus1st_greatest - bool_threshold, 0.0)
808
-
809
- # perform the shift
810
- new_params = {}
811
- for (var, param) in params.items():
812
- if ranges[var] == 'bool':
813
- if noop[var]:
814
- new_param = param + surplus
815
- else:
816
- new_param = param - surplus
817
- new_param = _jax_project_bool_to_box(var, new_param, hyperparams)
818
- else:
819
- new_param = param
820
- new_params[var] = new_param
821
- return new_params, True
822
-
841
+ jax_project_fn = JaxSortingActionProjection().compile(
842
+ ranges, noop, wrap_sigmoid, allowed_actions, bool_threshold,
843
+ _jax_project_bool_to_box
844
+ )
845
+
823
846
  # clip actions to valid bounds and satisfy constraint on max actions
824
847
  def _jax_wrapped_slp_project_to_max_constraint(params, hyperparams):
825
848
  params, _ = _jax_wrapped_slp_project_to_box(params, hyperparams)
826
- project_over_horizon = jax.vmap(
827
- _jax_wrapped_sorting_project, in_axes=(0, None)
828
- )(params, hyperparams)
829
- return project_over_horizon
830
-
849
+ return jax.vmap(jax_project_fn, in_axes=(0, None))(params, hyperparams)
831
850
  self.projection = _jax_wrapped_slp_project_to_max_constraint
832
851
 
833
- # use SOGBOFA projection method...
852
+ # use SOGBOFA projection method
834
853
  elif use_constraint_satisfaction and not self._use_new_projection:
854
+ jax_project_fn = JaxSogbofaActionProjection().compile(
855
+ ranges, noop, allowed_actions, self._max_constraint_iter,
856
+ _jax_bool_param_to_action, _jax_bool_action_to_param,
857
+ min_action, max_action, compiled.REAL
858
+ )
835
859
 
836
- # calculate the surplus of actions above max-nondef-actions
837
- def _jax_wrapped_sogbofa_surplus(actions):
838
- sum_action, k = 0.0, 0
839
- for (var, action) in actions.items():
840
- if ranges[var] == 'bool':
841
- if noop[var]:
842
- action = 1 - action
843
- sum_action += jnp.sum(action)
844
- k += jnp.count_nonzero(action)
845
- surplus = jnp.maximum(sum_action - allowed_actions, 0.0)
846
- return surplus, k
847
-
848
- # return whether the surplus is positive or reached compute limit
849
- max_constraint_iter = self._max_constraint_iter
850
-
851
- def _jax_wrapped_sogbofa_continue(values):
852
- it, _, surplus, k = values
853
- return jnp.logical_and(
854
- it < max_constraint_iter, jnp.logical_and(surplus > 0, k > 0))
855
-
856
- # reduce all bool action values by the surplus clipping at minimum
857
- # for no-op = True, do the opposite, i.e. increase all
858
- # bool action values by surplus clipping at maximum
859
- def _jax_wrapped_sogbofa_subtract_surplus(values):
860
- it, actions, surplus, k = values
861
- amount = surplus / k
862
- new_actions = {}
863
- for (var, action) in actions.items():
864
- if ranges[var] == 'bool':
865
- if noop[var]:
866
- new_actions[var] = jnp.minimum(action + amount, 1)
867
- else:
868
- new_actions[var] = jnp.maximum(action - amount, 0)
869
- else:
870
- new_actions[var] = action
871
- new_surplus, new_k = _jax_wrapped_sogbofa_surplus(new_actions)
872
- new_it = it + 1
873
- return new_it, new_actions, new_surplus, new_k
874
-
875
- # apply the surplus to the actions until it becomes zero
876
- def _jax_wrapped_sogbofa_project(params, hyperparams):
877
-
878
- # convert parameters to actions
879
- actions = {}
880
- for (var, param) in params.items():
881
- if ranges[var] == 'bool':
882
- actions[var] = _jax_bool_param_to_action(var, param, hyperparams)
883
- else:
884
- actions[var] = param
885
-
886
- # run SOGBOFA loop on the actions to get adjusted actions
887
- surplus, k = _jax_wrapped_sogbofa_surplus(actions)
888
- _, actions, surplus, k = jax.lax.while_loop(
889
- cond_fun=_jax_wrapped_sogbofa_continue,
890
- body_fun=_jax_wrapped_sogbofa_subtract_surplus,
891
- init_val=(0, actions, surplus, k)
892
- )
893
- converged = jnp.logical_not(surplus > 0)
894
-
895
- # convert the adjusted actions back to parameters
896
- new_params = {}
897
- for (var, action) in actions.items():
898
- if ranges[var] == 'bool':
899
- action = jnp.clip(action, min_action, max_action)
900
- param = _jax_bool_action_to_param(var, action, hyperparams)
901
- new_params[var] = param
902
- else:
903
- new_params[var] = action
904
- return new_params, converged
905
-
906
860
  # clip actions to valid bounds and satisfy constraint on max actions
907
861
  def _jax_wrapped_slp_project_to_max_constraint(params, hyperparams):
908
862
  params, _ = _jax_wrapped_slp_project_to_box(params, hyperparams)
909
- project_over_horizon = jax.vmap(
910
- _jax_wrapped_sogbofa_project, in_axes=(0, None)
911
- )(params, hyperparams)
912
- return project_over_horizon
913
-
863
+ return jax.vmap(jax_project_fn, in_axes=(0, None))(params, hyperparams)
914
864
  self.projection = _jax_wrapped_slp_project_to_max_constraint
915
865
 
916
866
  # just project to box constraints
@@ -925,34 +875,32 @@ class JaxStraightLinePlan(JaxPlan):
925
875
  init = self._initializer
926
876
  stack_bool_params = use_constraint_satisfaction and self._wrap_softmax
927
877
 
928
- def _jax_wrapped_slp_init(key, hyperparams, subs):
878
+ # use the user required initializer and project actions to feasible range
879
+ def _jax_wrapped_slp_init(key, hyperparams, fls):
929
880
  params = {}
930
881
  for (var, shape) in shapes.items():
931
882
  if ranges[var] != 'bool' or not stack_bool_params:
932
883
  key, subkey = random.split(key)
933
884
  param = init(key=subkey, shape=shape, dtype=compiled.REAL)
934
885
  if ranges[var] == 'bool':
935
- param += bool_threshold
886
+ param = param + bool_threshold
936
887
  params[var] = param
937
888
  if stack_bool_params:
938
889
  key, subkey = random.split(key)
939
890
  bool_shape = (horizon, bool_action_count)
940
- bool_param = init(key=subkey, shape=bool_shape, dtype=compiled.REAL)
941
- params[bool_key] = bool_param
891
+ params[bool_key] = init(key=subkey, shape=bool_shape, dtype=compiled.REAL)
942
892
  params, _ = _jax_wrapped_slp_project_to_box(params, hyperparams)
943
893
  return params
944
-
945
894
  self.initializer = _jax_wrapped_slp_init
946
895
 
947
896
  @staticmethod
948
897
  @jax.jit
949
898
  def _guess_next_epoch(param):
950
- # "progress" the plan one step forward and set last action to second-last
951
899
  return jnp.append(param[1:, ...], param[-1:, ...], axis=0)
952
900
 
953
901
  def guess_next_epoch(self, params: Pytree) -> Pytree:
954
- next_fn = JaxStraightLinePlan._guess_next_epoch
955
- return jax.tree_util.tree_map(next_fn, params)
902
+ # "progress" the plan one step forward and set last action to second-last
903
+ return jax.tree_util.tree_map(JaxStraightLinePlan._guess_next_epoch, params)
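The refactored guess_next_epoch above warm-starts the next optimization round by shifting the plan one step forward and repeating the final action. A toy illustration on a single (horizon, action) parameter array:

    import jax.numpy as jnp

    param = jnp.array([[1.0], [2.0], [3.0]])                      # plan over a horizon of 3
    shifted = jnp.append(param[1:, ...], param[-1:, ...], axis=0)
    # shifted == [[2.0], [3.0], [3.0]]: each step moves up by one and
    # the last action is duplicated as the guess for the final step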
956
904
 
957
905
 
958
906
  class JaxDeepReactivePolicy(JaxPlan):
@@ -997,7 +945,7 @@ class JaxDeepReactivePolicy(JaxPlan):
997
945
  def __str__(self) -> str:
998
946
  bounds = '\n '.join(
999
947
  map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
1000
- return (f'policy hyper-parameters:\n'
948
+ return (f'[INFO] policy hyper-parameters:\n'
1001
949
  f' topology ={self._topology}\n'
1002
950
  f' activation_fn={self._activations[0].__name__}\n'
1003
951
  f' initializer ={type(self._initializer_base).__name__}\n'
@@ -1021,13 +969,18 @@ class JaxDeepReactivePolicy(JaxPlan):
1021
969
  shapes = {var: value[1:] for (var, value) in shapes.items()}
1022
970
  self.bounds = bounds
1023
971
 
1024
- # action concurrency check - only allow one action non-noop for now
972
+ # enable constraint satisfaction subroutines during optimization
973
+ # if there are nontrivial concurrency constraints in the problem description
974
+ # only handles the case where |A| = 1 for now, as there is no way to do projection
975
+ # currently (TODO: fix this)
1025
976
  bool_action_count, allowed_actions = self._count_bool_actions(rddl)
1026
977
  if 1 < allowed_actions < bool_action_count:
1027
978
  raise RDDLNotImplementedError(
1028
- f'DRPs currently do not support max-nondef-actions {allowed_actions} > 1.')
979
+ f'DRPs currently do not support max-nondef-actions {allowed_actions} > 1.'
980
+ )
1029
981
  use_constraint_satisfaction = allowed_actions < bool_action_count
1030
-
982
+
983
+ # get the noop action values
1031
984
  noop = {var: (values[0] if isinstance(values, list) else values)
1032
985
  for (var, values) in rddl.action_fluents.items()}
1033
986
  bool_key = 'bool__'
@@ -1036,7 +989,8 @@ class JaxDeepReactivePolicy(JaxPlan):
1036
989
  # POLICY NETWORK PREDICTION
1037
990
  #
1038
991
  # ***********************************************************************
1039
-
992
+
993
+ # compute the correct shapes of the output layers based on the action-fluent shape
1040
994
  ranges = rddl.variable_ranges
1041
995
  normalize = self._normalize
1042
996
  normalize_per_layer = self._normalize_per_layer
@@ -1047,14 +1001,15 @@ class JaxDeepReactivePolicy(JaxPlan):
1047
1001
  for (var, shape) in shapes.items()}
1048
1002
  layer_names = {var: f'output_{var}'.replace('-', '_') for var in shapes}
1049
1003
 
1050
- # inputs for the policy network
1004
+ # inputs for the policy network are states for fully observed and obs for POMDPs
1051
1005
  if rddl.observ_fluents:
1052
1006
  observed_vars = rddl.observ_fluents
1053
1007
  else:
1054
1008
  observed_vars = rddl.state_fluents
1055
1009
  input_names = {var: f'{var}'.replace('-', '_') for var in observed_vars}
1056
1010
 
1057
- # catch if input norm is applied to size 1 tensor
1011
+ # catch if input norm is applied to size 1 tensor:
1012
+ # this leads to incorrect behavior as the input is always "1"
1058
1013
  if normalize:
1059
1014
  non_bool_dims = 0
1060
1015
  for (var, values) in observed_vars.items():
@@ -1062,33 +1017,33 @@ class JaxDeepReactivePolicy(JaxPlan):
1062
1017
  value_size = np.size(values)
1063
1018
  if normalize_per_layer and value_size == 1:
1064
1019
  if compiled.print_warnings:
1065
- message = termcolor.colored(
1020
+ print(termcolor.colored(
1066
1021
  f'[WARN] Cannot apply layer norm to state-fluent <{var}> '
1067
- f'of size 1: setting normalize_per_layer = False.', 'yellow')
1068
- print(message)
1022
+ f'of size 1: setting normalize_per_layer = False.', 'yellow'
1023
+ ))
1069
1024
  normalize_per_layer = False
1070
1025
  non_bool_dims += value_size
1071
1026
  if not normalize_per_layer and non_bool_dims == 1:
1072
1027
  if compiled.print_warnings:
1073
- message = termcolor.colored(
1028
+ print(termcolor.colored(
1074
1029
  '[WARN] Cannot apply layer norm to state-fluents of total size 1: '
1075
- 'setting normalize = False.', 'yellow')
1076
- print(message)
1030
+ 'setting normalize = False.', 'yellow'
1031
+ ))
1077
1032
  normalize = False
1078
1033
 
1079
- # convert subs dictionary into a state vector to feed to the MLP
1080
- def _jax_wrapped_policy_input(subs, hyperparams):
1034
+ # convert fluents dictionary into a state vector to feed to the MLP
1035
+ def _jax_wrapped_policy_input(fls, hyperparams):
1081
1036
 
1082
1037
  # optional state preprocessing
1083
1038
  if preprocessor is not None:
1084
1039
  stats = hyperparams[preprocessor.HYPERPARAMS_KEY]
1085
- subs = preprocessor.transform(subs, stats)
1040
+ fls = preprocessor.transform(fls, stats)
1086
1041
 
1087
1042
  # concatenate all state variables into a single vector
1088
1043
  # optionally apply layer norm to each input tensor
1089
1044
  states_bool, states_non_bool = [], []
1090
1045
  non_bool_dims = 0
1091
- for (var, value) in subs.items():
1046
+ for (var, value) in fls.items():
1092
1047
  if var in observed_vars:
1093
1048
  state = jnp.ravel(value, order='C')
1094
1049
  if ranges[var] == 'bool':
@@ -1103,7 +1058,7 @@ class JaxDeepReactivePolicy(JaxPlan):
1103
1058
  )
1104
1059
  state = normalizer(state)
1105
1060
  states_non_bool.append(state)
1106
- non_bool_dims += state.size
1061
+ non_bool_dims = non_bool_dims + state.size
1107
1062
  state = jnp.concatenate(states_non_bool + states_bool)
1108
1063
 
1109
1064
  # optionally perform layer normalization on the non-bool inputs
@@ -1119,8 +1074,8 @@ class JaxDeepReactivePolicy(JaxPlan):
1119
1074
  return state
1120
1075
 
1121
1076
  # predict actions from the policy network for current state
1122
- def _jax_wrapped_policy_network_predict(subs, hyperparams):
1123
- state = _jax_wrapped_policy_input(subs, hyperparams)
1077
+ def _jax_wrapped_policy_network_predict(fls, hyperparams):
1078
+ state = _jax_wrapped_policy_input(fls, hyperparams)
1124
1079
 
1125
1080
  # feed state vector through hidden layers
1126
1081
  hidden = state
@@ -1139,37 +1094,37 @@ class JaxDeepReactivePolicy(JaxPlan):
1139
1094
  if not shapes[var]:
1140
1095
  output = jnp.squeeze(output)
1141
1096
 
1142
- # project action output to valid box constraints
1097
+ # project action output to valid box constraints following Bueno et al.
1143
1098
  if ranges[var] == 'bool':
1144
1099
  if not use_constraint_satisfaction:
1145
- actions[var] = jax.nn.sigmoid(output)
1100
+ actions[var] = stable_sigmoid(output)
1146
1101
  else:
1147
1102
  if wrap_non_bool:
1148
1103
  lower, upper = bounds_safe[var]
1149
1104
  mb, ml, mu, mn = [jnp.asarray(mask, dtype=compiled.REAL)
1150
1105
  for mask in cond_lists[var]]
1151
- action = (
1152
- mb * (lower + (upper - lower) * jax.nn.sigmoid(output)) +
1153
- ml * (lower + (jax.nn.elu(output) + 1.0)) +
1154
- mu * (upper - (jax.nn.elu(-output) + 1.0)) +
1106
+ actions[var] = (
1107
+ mb * (lower + (upper - lower) * stable_sigmoid(output)) +
1108
+ ml * (lower + jax.nn.softplus(output)) +
1109
+ mu * (upper - jax.nn.softplus(-output)) +
1155
1110
  mn * output
1156
1111
  )
1157
1112
  else:
1158
- action = output
1159
- actions[var] = action
1113
+ actions[var] = output
1160
1114
 
1161
- # for constraint satisfaction wrap bool actions with softmax
1115
+ # for constraint satisfaction wrap bool actions with softmax:
1116
+ # this only works when |A| = 1
1162
1117
  if use_constraint_satisfaction:
1163
1118
  linear = hk.Linear(bool_action_count, name='output_bool', w_init=init)
1164
- output = jax.nn.softmax(linear(hidden))
1165
- actions[bool_key] = output
1166
-
1119
+ actions[bool_key] = jax.nn.softmax(linear(hidden))
1167
1120
  return actions
1168
1121
 
1122
+ # we need pure JAX functions for the policy network prediction
1169
1123
  predict_fn = hk.transform(_jax_wrapped_policy_network_predict)
1170
1124
  predict_fn = hk.without_apply_rng(predict_fn)
1171
1125
 
1172
- # convert softmax action back to action dict
1126
+ # given a softmax output, this simply unpacks the result of the softmax back into
1127
+ # the original action fluent dictionary
1173
1128
  def _jax_unstack_bool_from_softmax(output):
1174
1129
  actions = {}
1175
1130
  start = 0
@@ -1180,12 +1135,13 @@ class JaxDeepReactivePolicy(JaxPlan):
1180
1135
  if noop[name]:
1181
1136
  action = 1.0 - action
1182
1137
  actions[name] = action
1183
- start += size
1138
+ start = start + size
1184
1139
  return actions
1185
1140
 
1186
- # train action prediction
1187
- def _jax_wrapped_drp_predict_train(key, params, hyperparams, step, subs):
1188
- actions = predict_fn.apply(params, subs, hyperparams)
1141
+ # the main subroutine to compute the trainable rddl actions from the trainable
1142
+ # parameters and the current state/obs dictionary
1143
+ def _jax_wrapped_drp_predict_train(key, params, hyperparams, step, fls):
1144
+ actions = predict_fn.apply(params, fls, hyperparams)
1189
1145
  if not wrap_non_bool:
1190
1146
  for (var, action) in actions.items():
1191
1147
  if var != bool_key and ranges[var] != 'bool':
@@ -1195,10 +1151,13 @@ class JaxDeepReactivePolicy(JaxPlan):
1195
1151
  actions.update(bool_actions)
1196
1152
  del actions[bool_key]
1197
1153
  return actions
1154
+ self.train_policy = _jax_wrapped_drp_predict_train
1198
1155
 
1199
- # test action prediction
1200
- def _jax_wrapped_drp_predict_test(key, params, hyperparams, step, subs):
1201
- actions = _jax_wrapped_drp_predict_train(key, params, hyperparams, step, subs)
1156
+ # the main subroutine to compute the test rddl actions from the trainable
1157
+ # parameters and state/obs dict: the difference here is that actions are converted
1158
+ # to their required types (i.e. bool, int, float)
1159
+ def _jax_wrapped_drp_predict_test(key, params, hyperparams, step, fls):
1160
+ actions = _jax_wrapped_drp_predict_train(key, params, hyperparams, step, fls)
1202
1161
  new_actions = {}
1203
1162
  for (var, action) in actions.items():
1204
1163
  prange = ranges[var]
@@ -1211,8 +1170,6 @@ class JaxDeepReactivePolicy(JaxPlan):
1211
1170
  new_action = jnp.clip(action, *bounds[var])
1212
1171
  new_actions[var] = new_action
1213
1172
  return new_actions
1214
-
1215
- self.train_policy = _jax_wrapped_drp_predict_train
1216
1173
  self.test_policy = _jax_wrapped_drp_predict_test
1217
1174
 
1218
1175
  # ***********************************************************************
@@ -1222,8 +1179,8 @@ class JaxDeepReactivePolicy(JaxPlan):
1222
1179
 
1223
1180
  # no projection applied since the actions are already constrained
1224
1181
  def _jax_wrapped_drp_no_projection(params, hyperparams):
1225
- return params, True
1226
-
1182
+ converged = jnp.array(True, dtype=jnp.bool_)
1183
+ return params, converged
1227
1184
  self.projection = _jax_wrapped_drp_no_projection
1228
1185
 
1229
1186
  # ***********************************************************************
@@ -1231,16 +1188,16 @@ class JaxDeepReactivePolicy(JaxPlan):
1231
1188
  #
1232
1189
  # ***********************************************************************
1233
1190
 
1234
- def _jax_wrapped_drp_init(key, hyperparams, subs):
1235
- subs = {var: value[0, ...]
1236
- for (var, value) in subs.items()
1237
- if var in observed_vars}
1238
- params = predict_fn.init(key, subs, hyperparams)
1239
- return params
1240
-
1191
+ # initialize policy parameters according to user-desired weight initializer
1192
+ def _jax_wrapped_drp_init(key, hyperparams, fls):
1193
+ obs_vars = {var: value[0, ...]
1194
+ for (var, value) in fls.items()
1195
+ if var in observed_vars}
1196
+ return predict_fn.init(key, obs_vars, hyperparams)
1241
1197
  self.initializer = _jax_wrapped_drp_init
1242
1198
 
1243
1199
  def guess_next_epoch(self, params: Pytree) -> Pytree:
1200
+ # this is easy: just warm-start from the previously obtained policy
1244
1201
  return params
1245
1202
 
1246
1203
 
@@ -1339,17 +1296,16 @@ class PGPE(metaclass=ABCMeta):
1339
1296
  self._update = None
1340
1297
 
1341
1298
  @property
1342
- def initialize(self):
1299
+ def initialize(self) -> Callable:
1343
1300
  return self._initializer
1344
1301
 
1345
1302
  @property
1346
- def update(self):
1303
+ def update(self) -> Callable:
1347
1304
  return self._update
1348
1305
 
1349
1306
  @abstractmethod
1350
1307
  def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type,
1351
- print_warnings: bool,
1352
- parallel_updates: Optional[int]=None) -> None:
1308
+ print_warnings: bool, parallel_updates: int=1) -> None:
1353
1309
  pass
1354
1310
 
1355
1311
 
@@ -1414,11 +1370,10 @@ class GaussianPGPE(PGPE):
1414
1370
  mu_optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs_mu)
1415
1371
  sigma_optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs_sigma)
1416
1372
  except Exception as _:
1417
- message = termcolor.colored(
1418
- '[FAIL] Failed to inject hyperparameters into PGPE optimizer, '
1419
- 'rolling back to safer method: '
1420
- 'kl-divergence constraint will be disabled.', 'red')
1421
- print(message)
1373
+ print(termcolor.colored(
1374
+ '[WARN] Could not inject hyperparameters into PGPE optimizer: '
1375
+ 'kl-divergence constraint will be disabled.', 'yellow'
1376
+ ))
1422
1377
  mu_optimizer = optimizer(**optimizer_kwargs_mu)
1423
1378
  sigma_optimizer = optimizer(**optimizer_kwargs_sigma)
1424
1379
  max_kl_update = None
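The try/except above falls back to plain optimizers when hyper-parameter injection fails, at the cost of disabling the KL-divergence constraint, presumably because that constraint adjusts the step size through the injected hyper-parameters. A minimal sketch of the optax mechanism being relied on; optax.adam and the learning-rate values are arbitrary choices for illustration:

    import optax

    # wrap the optimizer factory so its hyper-parameters live in the optimizer state
    opt = optax.inject_hyperparams(optax.adam)(learning_rate=1e-2)
    state = opt.init({'w': 0.0})

    # the injected hyper-parameter can be inspected or rescaled between updates
    state.hyperparams['learning_rate'] = 5e-3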
@@ -1426,7 +1381,7 @@ class GaussianPGPE(PGPE):
1426
1381
  self.max_kl = max_kl_update
1427
1382
 
1428
1383
  def __str__(self) -> str:
1429
- return (f'PGPE hyper-parameters:\n'
1384
+ return (f'[INFO] PGPE hyper-parameters:\n'
1430
1385
  f' method ={self.__class__.__name__}\n'
1431
1386
  f' batch_size ={self.batch_size}\n'
1432
1387
  f' init_sigma ={self.init_sigma}\n'
@@ -1444,9 +1399,11 @@ class GaussianPGPE(PGPE):
1444
1399
  f' max_kl_update ={self.max_kl}\n'
1445
1400
  )
1446
1401
 
1447
- def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type,
1402
+ def compile(self, loss_fn: Callable,
1403
+ projection: Callable,
1404
+ real_dtype: Type,
1448
1405
  print_warnings: bool,
1449
- parallel_updates: Optional[int]=None) -> None:
1406
+ parallel_updates: int=1) -> None:
1450
1407
  sigma0 = self.init_sigma
1451
1408
  sigma_lo, sigma_hi = self.sigma_range
1452
1409
  scale_reward = self.scale_reward
@@ -1458,6 +1415,7 @@ class GaussianPGPE(PGPE):
1458
1415
  max_kl = self.max_kl
1459
1416
 
1460
1417
  # entropy regularization penalty is decayed exponentially by elapsed budget
1418
+ # this uses the optimizer progress (as a percentage of the budget) to advance the decay
1461
1419
  start_entropy_coeff = self.start_entropy_coeff
1462
1420
  if start_entropy_coeff == 0:
1463
1421
  entropy_coeff_decay = 0
@@ -1469,6 +1427,8 @@ class GaussianPGPE(PGPE):
1469
1427
  #
1470
1428
  # ***********************************************************************
1471
1429
 
1430
+ # use the default initializer for the (mean, sigma) parameters
1431
+ # these parameters define the sampling distribution over policy parameters
1472
1432
  def _jax_wrapped_pgpe_init(key, policy_params):
1473
1433
  mu = policy_params
1474
1434
  sigma = jax.tree_util.tree_map(partial(jnp.full_like, fill_value=sigma0), mu)
@@ -1477,51 +1437,60 @@ class GaussianPGPE(PGPE):
1477
1437
  r_max = -jnp.inf
1478
1438
  return pgpe_params, pgpe_opt_state, r_max
1479
1439
 
1480
- if parallel_updates is None:
1481
- self._initializer = jax.jit(_jax_wrapped_pgpe_init)
1482
- else:
1483
-
1484
- # for parallel policy update
1485
- def _jax_wrapped_pgpe_inits(key, policy_params):
1486
- keys = jnp.asarray(random.split(key, num=parallel_updates))
1487
- return jax.vmap(_jax_wrapped_pgpe_init, in_axes=0)(keys, policy_params)
1488
-
1489
- self._initializer = jax.jit(_jax_wrapped_pgpe_inits)
1440
+ # for parallel policy update, initialize multiple independent (mean, sigma)
1441
+ # Gaussians that will be optimized in parallel
1442
+ def _jax_wrapped_batched_pgpe_init(key, policy_params):
1443
+ keys = random.split(key, num=parallel_updates)
1444
+ return jax.vmap(_jax_wrapped_pgpe_init, in_axes=0)(keys, policy_params)
1445
+
1446
+ self._initializer = jax.jit(_jax_wrapped_batched_pgpe_init)
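As a sketch of what this batched initializer produces (a hypothetical stand-in for _jax_wrapped_pgpe_init with made-up shapes), vmap over the split keys stacks parallel_updates independent (mean, sigma) pairs along a new leading axis:

import jax
import jax.numpy as jnp
from jax import random

def init_one(key, mu0):
    # stand-in for the real initializer: copy the mean and fill sigma with a constant
    return mu0, jnp.full_like(mu0, 0.5)

keys = random.split(random.PRNGKey(0), num=4)            # parallel_updates = 4
mus, sigmas = jax.vmap(init_one, in_axes=(0, None))(keys, jnp.zeros(3))
# mus.shape == (4, 3) and sigmas.shape == (4, 3): one row per parallel optimizer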
1490
1447
 
1491
1448
  # ***********************************************************************
1492
1449
  # PARAMETER SAMPLING FUNCTIONS
1493
1450
  #
1494
1451
  # ***********************************************************************
1495
1452
 
1453
+ # sample from i.i.d. Normal(0, sigma)
1496
1454
  def _jax_wrapped_mu_noise(key, sigma):
1497
1455
  return sigma * random.normal(key, shape=jnp.shape(sigma), dtype=real_dtype)
1498
1456
 
1499
1457
  # this samples a noise variable epsilon* from epsilon with the N(0, 1) density
1500
- # according to super-symmetric sampling paper
1458
+ # according to super-symmetric sampling paper:
1459
+ # the paper presents a more accurate formula which is used by default
1501
1460
  def _jax_wrapped_epsilon_star(sigma, epsilon):
1502
- c1, c2, c3 = -0.06655, -0.9706, 0.124
1503
1461
  phi = 0.67449 * sigma
1504
1462
  a = (sigma - jnp.abs(epsilon)) / sigma
1463
+
1464
+ # more accurate formula
1505
1465
  if super_symmetric_accurate:
1506
1466
  aa = jnp.abs(a)
1507
- aa3 = jnp.power(aa, 3)
1508
- epsilon_star = jnp.sign(epsilon) * phi * jnp.where(
1509
- a <= 0,
1510
- jnp.exp(c1 * (aa3 - aa) / jnp.log(aa + 1e-10) + c2 * aa),
1511
- jnp.exp(aa - c3 * aa * jnp.log(1.0 - aa3 + 1e-10))
1512
- )
1467
+ atol = 1e-10
1468
+ c1, c2, c3 = -0.06655, -0.9706, 0.124
1469
+ term_neg_log = c1 * (aa * aa - 1.) / jnp.log(aa + atol) + c2
1470
+ term_pos_log = 1. - c3 * jnp.log1p(-aa ** 3 + atol)
1471
+ epsilon_star = jnp.sign(epsilon) * phi * jnp.exp(
1472
+ aa * jnp.where(a <= 0, term_neg_log, term_pos_log))
1473
+
1474
+ # less accurate and simple formula
1513
1475
  else:
1514
1476
  epsilon_star = jnp.sign(epsilon) * phi * jnp.exp(a)
1515
1477
  return epsilon_star
1516
1478
 
1517
1479
  # implements baseline-free super-symmetric sampling to generate 4 trajectories
1480
+ # this type of sampling removes the need for the baseline completely
1518
1481
  def _jax_wrapped_sample_params(key, mu, sigma):
1482
+
1483
+ # this samples the two basic policy parameter vectors from Gaussian(mean, sigma)
1484
+ # using the control variates
1519
1485
  treedef = jax.tree_util.tree_structure(sigma)
1520
1486
  keys = random.split(key, num=treedef.num_leaves)
1521
1487
  keys_pytree = jax.tree_util.tree_unflatten(treedef=treedef, leaves=keys)
1522
1488
  epsilon = jax.tree_util.tree_map(_jax_wrapped_mu_noise, keys_pytree, sigma)
1523
1489
  p1 = jax.tree_util.tree_map(jnp.add, mu, epsilon)
1524
1490
  p2 = jax.tree_util.tree_map(jnp.subtract, mu, epsilon)
1491
+
1492
+ # super-symmetric sampling removes the need for a baseline but requires
1493
+ # two additional policies to be sampled
1525
1494
  if super_symmetric:
1526
1495
  epsilon_star = jax.tree_util.tree_map(
1527
1496
  _jax_wrapped_epsilon_star, sigma, epsilon)
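For orientation, a minimal scalar sketch of the four-point construction above (assumed values; it uses the simple, less accurate epsilon* branch, and the second mirrored pair is formed from epsilon* outside the lines shown in this hunk):

import jax.numpy as jnp
from jax import random

mu, sigma = 0.0, 0.5
epsilon = sigma * random.normal(random.PRNGKey(0))
phi = 0.67449 * sigma                            # median of |Normal(0, sigma^2)|
epsilon_star = jnp.sign(epsilon) * phi * jnp.exp((sigma - jnp.abs(epsilon)) / sigma)
p1, p2 = mu + epsilon, mu - epsilon              # mirrored pair around the mean
p3, p4 = mu + epsilon_star, mu - epsilon_star    # super-symmetric pair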
@@ -1538,6 +1507,8 @@ class GaussianPGPE(PGPE):
1538
1507
 
1539
1508
  # gradient with respect to mean
1540
1509
  def _jax_wrapped_mu_grad(epsilon, epsilon_star, r1, r2, r3, r4, m):
1510
+
1511
+ # for super symmetric sampling
1541
1512
  if super_symmetric:
1542
1513
  if scale_reward:
1543
1514
  scale1 = jnp.maximum(min_reward_scale, m - (r1 + r2) / 2)
@@ -1547,6 +1518,8 @@ class GaussianPGPE(PGPE):
1547
1518
  r_mu1 = (r1 - r2) / (2 * scale1)
1548
1519
  r_mu2 = (r3 - r4) / (2 * scale2)
1549
1520
  grad = -(r_mu1 * epsilon + r_mu2 * epsilon_star)
1521
+
1522
+ # for the basic pgpe
1550
1523
  else:
1551
1524
  if scale_reward:
1552
1525
  scale = jnp.maximum(min_reward_scale, m - (r1 + r2) / 2)
@@ -1558,6 +1531,8 @@ class GaussianPGPE(PGPE):
1558
1531
 
1559
1532
  # gradient with respect to std. deviation
1560
1533
  def _jax_wrapped_sigma_grad(epsilon, epsilon_star, sigma, r1, r2, r3, r4, m, ent):
1534
+
1535
+ # for super symmetric sampling
1561
1536
  if super_symmetric:
1562
1537
  mask = r1 + r2 >= r3 + r4
1563
1538
  epsilon_tau = mask * epsilon + (1 - mask) * epsilon_star
@@ -1567,6 +1542,8 @@ class GaussianPGPE(PGPE):
1567
1542
  else:
1568
1543
  scale = 1.0
1569
1544
  r_sigma = ((r1 + r2) - (r3 + r4)) / (4 * scale)
1545
+
1546
+ # for basic pgpe
1570
1547
  else:
1571
1548
  s = jnp.square(epsilon) / sigma - sigma
1572
1549
  if scale_reward:
@@ -1574,30 +1551,40 @@ class GaussianPGPE(PGPE):
1574
1551
  else:
1575
1552
  scale = 1.0
1576
1553
  r_sigma = (r1 + r2) / (2 * scale)
1577
- grad = -(r_sigma * s + ent / sigma)
1578
- return grad
1554
+
1555
+ return -(r_sigma * s + ent / sigma)
1579
1556
 
1580
1557
  # calculate the policy gradients
1581
1558
  def _jax_wrapped_pgpe_grad(key, mu, sigma, r_max, ent,
1582
- policy_hyperparams, subs, model_params):
1559
+ policy_hyperparams, fls, nfls, model_params):
1560
+
1561
+ # basic pgpe sampling with return estimation
1583
1562
  key, subkey = random.split(key)
1584
1563
  p1, p2, p3, p4, epsilon, epsilon_star = _jax_wrapped_sample_params(
1585
1564
  key, mu, sigma)
1586
- r1 = -loss_fn(subkey, p1, policy_hyperparams, subs, model_params)[0]
1587
- r2 = -loss_fn(subkey, p2, policy_hyperparams, subs, model_params)[0]
1565
+ r1 = -loss_fn(subkey, p1, policy_hyperparams, fls, nfls, model_params)[0]
1566
+ r2 = -loss_fn(subkey, p2, policy_hyperparams, fls, nfls, model_params)[0]
1567
+
1568
+ # track the running maximum return, used to scale rewards for optimizer stability
1588
1569
  r_max = jnp.maximum(r_max, r1)
1589
1570
  r_max = jnp.maximum(r_max, r2)
1571
+
1572
+ # super symmetric sampling requires two more trajectories and their returns
1590
1573
  if super_symmetric:
1591
- r3 = -loss_fn(subkey, p3, policy_hyperparams, subs, model_params)[0]
1592
- r4 = -loss_fn(subkey, p4, policy_hyperparams, subs, model_params)[0]
1574
+ r3 = -loss_fn(subkey, p3, policy_hyperparams, fls, nfls, model_params)[0]
1575
+ r4 = -loss_fn(subkey, p4, policy_hyperparams, fls, nfls, model_params)[0]
1593
1576
  r_max = jnp.maximum(r_max, r3)
1594
1577
  r_max = jnp.maximum(r_max, r4)
1595
1578
  else:
1596
- r3, r4 = r1, r2
1579
+ r3, r4 = r1, r2
1580
+
1581
+ # calculate gradient with respect to the mean
1597
1582
  grad_mu = jax.tree_util.tree_map(
1598
1583
  partial(_jax_wrapped_mu_grad, r1=r1, r2=r2, r3=r3, r4=r4, m=r_max),
1599
1584
  epsilon, epsilon_star
1600
1585
  )
1586
+
1587
+ # calculate gradient with respect to the sigma
1601
1588
  grad_sigma = jax.tree_util.tree_map(
1602
1589
  partial(_jax_wrapped_sigma_grad,
1603
1590
  r1=r1, r2=r2, r3=r3, r4=r4, m=r_max, ent=ent),
@@ -1605,21 +1592,30 @@ class GaussianPGPE(PGPE):
1605
1592
  )
1606
1593
  return grad_mu, grad_sigma, r_max
1607
1594
 
1595
+ # calculate the policy gradients with batching on the first dimension
1608
1596
  def _jax_wrapped_pgpe_grad_batched(key, pgpe_params, r_max, ent,
1609
- policy_hyperparams, subs, model_params):
1597
+ policy_hyperparams, fls, nfls, model_params):
1610
1598
  mu, sigma = pgpe_params
1599
+
1600
+ # no batching required
1611
1601
  if batch_size == 1:
1612
1602
  mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad(
1613
- key, mu, sigma, r_max, ent, policy_hyperparams, subs, model_params)
1603
+ key, mu, sigma, r_max, ent, policy_hyperparams, fls, nfls, model_params)
1604
+
1605
+ # for batching, need to handle how the gradients of mean and sigma are aggregated
1614
1606
  else:
1607
+ # do the batched calculation of mean and sigma gradients
1615
1608
  keys = random.split(key, num=batch_size)
1616
1609
  mu_grads, sigma_grads, r_maxs = jax.vmap(
1617
1610
  _jax_wrapped_pgpe_grad,
1618
- in_axes=(0, None, None, None, None, None, None, None)
1619
- )(keys, mu, sigma, r_max, ent, policy_hyperparams, subs, model_params)
1611
+ in_axes=(0, None, None, None, None, None, None, None, None)
1612
+ )(keys, mu, sigma, r_max, ent, policy_hyperparams, fls, nfls, model_params)
1613
+
1614
+ # calculate the average gradient for aggregation
1620
1615
  mu_grad, sigma_grad = jax.tree_util.tree_map(
1621
1616
  partial(jnp.mean, axis=0), (mu_grads, sigma_grads))
1622
1617
  new_r_max = jnp.max(r_maxs)
1618
+
1623
1619
  return mu_grad, sigma_grad, new_r_max
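The aggregation above is a per-leaf average over the batch axis; a tiny self-contained example with made-up gradients:

from functools import partial
import jax
import jax.numpy as jnp

grads = {'mu': jnp.array([[0., 1.], [2., 3.], [4., 5.]])}    # a batch of 3 gradients
avg = jax.tree_util.tree_map(partial(jnp.mean, axis=0), grads)
# avg['mu'] is [2., 3.], the element-wise mean over the batch dimension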
1624
1620
 
1625
1621
  # ***********************************************************************
@@ -1646,17 +1642,16 @@ class GaussianPGPE(PGPE):
1646
1642
  return new_mu, new_sigma, new_mu_state, new_sigma_state
1647
1643
 
1648
1644
  def _jax_wrapped_pgpe_update(key, pgpe_params, r_max, progress,
1649
- policy_hyperparams, subs, model_params,
1645
+ policy_hyperparams, fls, nfls, model_params,
1650
1646
  pgpe_opt_state):
1651
- # regular update
1647
+ # regular update for pgpe
1652
1648
  mu, sigma = pgpe_params
1653
1649
  mu_state, sigma_state = pgpe_opt_state
1654
1650
  ent = start_entropy_coeff * jnp.power(entropy_coeff_decay, progress)
1655
1651
  mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad_batched(
1656
- key, pgpe_params, r_max, ent, policy_hyperparams, subs, model_params)
1657
- new_mu, new_sigma, new_mu_state, new_sigma_state = \
1658
- _jax_wrapped_pgpe_update_helper(mu, sigma, mu_grad, sigma_grad,
1659
- mu_state, sigma_state)
1652
+ key, pgpe_params, r_max, ent, policy_hyperparams, fls, nfls, model_params)
1653
+ new_mu, new_sigma, new_mu_state, new_sigma_state = _jax_wrapped_pgpe_update_helper(
1654
+ mu, sigma, mu_grad, sigma_grad, mu_state, sigma_state)
1660
1655
 
1661
1656
  # respect KL divergence constraint with old parameters
1662
1657
  if max_kl is not None:
@@ -1668,34 +1663,30 @@ class GaussianPGPE(PGPE):
1668
1663
  kl_reduction = jnp.minimum(1.0, jnp.sqrt(max_kl / total_kl))
1669
1664
  mu_state.hyperparams['learning_rate'] = old_mu_lr * kl_reduction
1670
1665
  sigma_state.hyperparams['learning_rate'] = old_sigma_lr * kl_reduction
1671
- new_mu, new_sigma, new_mu_state, new_sigma_state = \
1672
- _jax_wrapped_pgpe_update_helper(mu, sigma, mu_grad, sigma_grad,
1673
- mu_state, sigma_state)
1666
+ new_mu, new_sigma, new_mu_state, new_sigma_state = _jax_wrapped_pgpe_update_helper(
1667
+ mu, sigma, mu_grad, sigma_grad, mu_state, sigma_state)
1674
1668
  new_mu_state.hyperparams['learning_rate'] = old_mu_lr
1675
1669
  new_sigma_state.hyperparams['learning_rate'] = old_sigma_lr
1676
1670
 
1677
- # apply projection step and finalize results
1671
+ # apply projection step to the new mean, which becomes the policy parameters
1678
1672
  new_mu, converged = projection(new_mu, policy_hyperparams)
1673
+
1679
1674
  new_pgpe_params = (new_mu, new_sigma)
1680
1675
  new_pgpe_opt_state = (new_mu_state, new_sigma_state)
1681
1676
  policy_params = new_mu
1682
1677
  return new_pgpe_params, new_r_max, new_pgpe_opt_state, policy_params, converged
1683
1678
 
1684
- if parallel_updates is None:
1685
- self._update = jax.jit(_jax_wrapped_pgpe_update)
1686
- else:
1687
-
1688
- # for parallel policy update
1689
- def _jax_wrapped_pgpe_updates(key, pgpe_params, r_max, progress,
1690
- policy_hyperparams, subs, model_params,
1691
- pgpe_opt_state):
1692
- keys = jnp.asarray(random.split(key, num=parallel_updates))
1693
- return jax.vmap(
1694
- _jax_wrapped_pgpe_update, in_axes=(0, 0, 0, None, None, None, 0, 0)
1695
- )(keys, pgpe_params, r_max, progress, policy_hyperparams, subs,
1696
- model_params, pgpe_opt_state)
1697
-
1698
- self._update = jax.jit(_jax_wrapped_pgpe_updates)
1679
+ # for parallel policy update
1680
+ def _jax_wrapped_batched_pgpe_updates(key, pgpe_params, r_max, progress,
1681
+ policy_hyperparams, fls, nfls, model_params,
1682
+ pgpe_opt_state):
1683
+ keys = random.split(key, num=parallel_updates)
1684
+ return jax.vmap(
1685
+ _jax_wrapped_pgpe_update, in_axes=(0, 0, 0, None, None, None, None, 0, 0)
1686
+ )(keys, pgpe_params, r_max, progress, policy_hyperparams, fls, nfls,
1687
+ model_params, pgpe_opt_state)
1688
+
1689
+ self._update = jax.jit(_jax_wrapped_batched_pgpe_updates)
1699
1690
 
1700
1691
 
1701
1692
  # ***********************************************************************
@@ -1757,7 +1748,7 @@ def var_utility(returns: jnp.ndarray, alpha: float) -> float:
1757
1748
  def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
1758
1749
  var = jnp.percentile(returns, q=100 * alpha)
1759
1750
  mask = returns <= var
1760
- return jnp.sum(returns * mask) / jnp.maximum(1, jnp.sum(mask))
1751
+ return jnp.sum(returns * mask) / jnp.maximum(1, jnp.count_nonzero(mask))
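A quick sanity check of cvar_utility with made-up returns (for a boolean mask, count_nonzero appears equivalent to the previous sum, just a more explicit tail count):

returns = jnp.array([-4., -2., 0., 1., 3., 5., 6., 7., 8., 10.])
# with alpha = 0.3 the 30th percentile is 0.7, so the tail is {-4, -2, 0}
# and cvar_utility(returns, alpha=0.3) evaluates to (-4 - 2 + 0) / 3 = -2.0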
1761
1752
 
1762
1753
 
1763
1754
  # set of all currently valid built-in utility functions
@@ -1783,6 +1774,11 @@ UTILITY_LOOKUP = {
1783
1774
  # ***********************************************************************
1784
1775
 
1785
1776
 
1777
+ @jax.jit
1778
+ def pytree_at(tree: Pytree, i: int) -> Pytree:
1779
+ return jax.tree_util.tree_map(lambda x: x[i], tree)
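For example, with the parameters of two parallel policies stacked along a leading axis (hypothetical shapes), pytree_at recovers a single policy:

stacked = {'w': jnp.zeros((2, 64, 8)), 'b': jnp.zeros((2, 8))}
single = pytree_at(stacked, 0)
# single['w'].shape == (64, 8) and single['b'].shape == (8,)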
1780
+
1781
+
1786
1782
  class JaxBackpropPlanner:
1787
1783
  '''A class for optimizing an action sequence in the given RDDL MDP using
1788
1784
  gradient descent.'''
@@ -1792,30 +1788,29 @@ class JaxBackpropPlanner:
1792
1788
  batch_size_train: int=32,
1793
1789
  batch_size_test: Optional[int]=None,
1794
1790
  rollout_horizon: Optional[int]=None,
1795
- use64bit: bool=False,
1791
+ parallel_updates: int=1,
1796
1792
  action_bounds: Optional[Bounds]=None,
1797
1793
  optimizer: Callable[..., optax.GradientTransformation]=optax.rmsprop,
1798
1794
  optimizer_kwargs: Optional[Kwargs]=None,
1799
1795
  clip_grad: Optional[float]=None,
1800
1796
  line_search_kwargs: Optional[Kwargs]=None,
1801
1797
  noise_kwargs: Optional[Kwargs]=None,
1798
+ ema_decay: Optional[float]=None,
1802
1799
  pgpe: Optional[PGPE]=GaussianPGPE(),
1803
- logic: Logic=FuzzyLogic(),
1800
+ compiler: JaxRDDLCompilerWithGrad=DefaultJaxRDDLCompilerWithGrad,
1801
+ compiler_kwargs: Optional[Kwargs]=None,
1804
1802
  use_symlog_reward: bool=False,
1805
1803
  utility: Union[Callable[[jnp.ndarray], float], str]='mean',
1806
1804
  utility_kwargs: Optional[Kwargs]=None,
1807
- cpfs_without_grad: Optional[Set[str]]=None,
1808
- compile_non_fluent_exact: bool=True,
1809
1805
  logger: Optional[Logger]=None,
1806
+ dashboard: Optional[Any]=None,
1810
1807
  dashboard_viz: Optional[Any]=None,
1811
- print_warnings: bool=True,
1812
- parallel_updates: Optional[int]=None,
1813
1808
  preprocessor: Optional[Preprocessor]=None,
1814
1809
  python_functions: Optional[Dict[str, Callable]]=None) -> None:
1815
1810
  '''Creates a new gradient-based algorithm for optimizing action sequences
1816
1811
  (plan) in the given RDDL. Some operations will be converted to their
1817
1812
  differentiable counterparts; the specific operations can be customized
1818
- by providing a subclass of FuzzyLogic.
1813
+ by providing a tailored compiler class.
1819
1814
 
1820
1815
  :param rddl: the RDDL domain to optimize
1821
1816
  :param plan: the policy/plan representation to optimize
@@ -1823,9 +1818,8 @@ class JaxBackpropPlanner:
1823
1818
  step
1824
1819
  :param batch_size_test: how many rollouts to use to test the plan at each
1825
1820
  optimization step
1826
- :param rollout_horizon: lookahead planning horizon: None uses the
1827
- :param use64bit: whether to perform arithmetic in 64 bit
1828
- horizon parameter in the RDDL instance
1821
+ :param rollout_horizon: lookahead planning horizon; None uses the horizon of the RDDL instance
1822
+ :param parallel_updates: how many optimizers to run independently in parallel
1829
1823
  :param action_bounds: box constraints on actions
1830
1824
  :param optimizer: a factory for an optax SGD algorithm
1831
1825
  :param optimizer_kwargs: a dictionary of parameters to pass to the SGD
@@ -1834,9 +1828,10 @@ class JaxBackpropPlanner:
1834
1828
  :param line_search_kwargs: parameters to pass to optional line search
1835
1829
  method to scale learning rate
1836
1830
  :param noise_kwargs: parameters of optional gradient noise
1831
+ :param ema_decay: optional decay rate for an exponential moving average of past parameters
1837
1832
  :param pgpe: optional policy gradient to run alongside the planner
1838
- :param logic: a subclass of Logic for mapping exact mathematical
1839
- operations to their differentiable counterparts
1833
+ :param compiler: compiler class used to build the differentiable model for planning
1834
+ :param compiler_kwargs: keyword arguments for initializing the compiler
1840
1835
  :param use_symlog_reward: whether to use the symlog transform on the
1841
1836
  reward as a form of normalization
1842
1837
  :param utility: how to aggregate return observations to compute utility
@@ -1844,15 +1839,10 @@ class JaxBackpropPlanner:
1844
1839
  scalar, or a string identifying the utility function by name
1845
1840
  :param utility_kwargs: additional keyword arguments to pass hyper-
1846
1841
  parameters to the utility function call
1847
- :param cpfs_without_grad: which CPFs do not have gradients (use straight
1848
- through gradient trick)
1849
- :param compile_non_fluent_exact: whether non-fluent expressions
1850
- are always compiled using exact JAX expressions
1851
1842
  :param logger: to log information about compilation to file
1843
+ :param dashboard: optional dashboard to display training progress and results
1852
1844
  :param dashboard_viz: optional visualizer object from the environment
1853
1845
  to pass to the dashboard to visualize the policy
1854
- :param print_warnings: whether to print warnings
1855
- :param parallel_updates: how many optimizers to run independently in parallel
1856
1846
  :param preprocessor: optional preprocessor for state inputs to plan
1857
1847
  :param python_functions: dictionary of external Python functions to call from RDDL
1858
1848
  '''
@@ -1869,7 +1859,6 @@ class JaxBackpropPlanner:
1869
1859
  if action_bounds is None:
1870
1860
  action_bounds = {}
1871
1861
  self._action_bounds = action_bounds
1872
- self.use64bit = use64bit
1873
1862
  self.optimizer_name = optimizer
1874
1863
  if optimizer_kwargs is None:
1875
1864
  optimizer_kwargs = {'learning_rate': 0.1}
@@ -1877,9 +1866,9 @@ class JaxBackpropPlanner:
1877
1866
  self.clip_grad = clip_grad
1878
1867
  self.line_search_kwargs = line_search_kwargs
1879
1868
  self.noise_kwargs = noise_kwargs
1869
+ self.ema_decay = ema_decay
1880
1870
  self.pgpe = pgpe
1881
1871
  self.use_pgpe = pgpe is not None
1882
- self.print_warnings = print_warnings
1883
1872
  self.preprocessor = preprocessor
1884
1873
  if python_functions is None:
1885
1874
  python_functions = {}
@@ -1889,11 +1878,10 @@ class JaxBackpropPlanner:
1889
1878
  try:
1890
1879
  optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs)
1891
1880
  except Exception as _:
1892
- message = termcolor.colored(
1893
- '[FAIL] Failed to inject hyperparameters into JaxPlan optimizer, '
1894
- 'rolling back to safer method: please note that runtime modification of '
1895
- 'hyperparameters will be disabled.', 'red')
1896
- print(message)
1881
+ print(termcolor.colored(
1882
+ '[WARN] Could not inject hyperparameters into JaxPlan optimizer: '
1883
+ 'runtime modification of hyperparameters will be disabled.', 'yellow'
1884
+ ))
1897
1885
  optimizer = optimizer(**optimizer_kwargs)
1898
1886
 
1899
1887
  # apply optimizer chain of transformations
@@ -1905,6 +1893,8 @@ class JaxBackpropPlanner:
1905
1893
  pipeline.append(optimizer)
1906
1894
  if line_search_kwargs is not None:
1907
1895
  pipeline.append(optax.scale_by_zoom_linesearch(**line_search_kwargs))
1896
+ if ema_decay is not None:
1897
+ pipeline.append(optax.ema(ema_decay))
1908
1898
  self.optimizer = optax.chain(*pipeline)
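A rough sketch of the resulting transformation under assumed settings (default rmsprop with learning_rate=0.1, no clipping, noise or line search, and ema_decay=0.99):

opt = optax.chain(
    optax.inject_hyperparams(optax.rmsprop)(learning_rate=0.1),
    optax.ema(0.99),
)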
1909
1899
 
1910
1900
  # set utility
@@ -1914,99 +1904,75 @@ class JaxBackpropPlanner:
1914
1904
  if utility_fn is None:
1915
1905
  raise RDDLNotImplementedError(
1916
1906
  f'Utility function <{utility}> is not supported, '
1917
- f'must be one of {list(UTILITY_LOOKUP.keys())}.')
1907
+ f'must be one of {list(UTILITY_LOOKUP.keys())}.'
1908
+ )
1918
1909
  else:
1919
1910
  utility_fn = utility
1920
1911
  self.utility = utility_fn
1921
-
1922
1912
  if utility_kwargs is None:
1923
1913
  utility_kwargs = {}
1924
1914
  self.utility_kwargs = utility_kwargs
1925
1915
 
1926
- self.logic = logic
1927
- self.logic.set_use64bit(self.use64bit)
1916
+ if compiler_kwargs is None:
1917
+ compiler_kwargs = {}
1918
+ self.compiler_type = compiler
1919
+ self.compiler_kwargs = compiler_kwargs
1928
1920
  self.use_symlog_reward = use_symlog_reward
1929
- if cpfs_without_grad is None:
1930
- cpfs_without_grad = set()
1931
- self.cpfs_without_grad = cpfs_without_grad
1932
- self.compile_non_fluent_exact = compile_non_fluent_exact
1921
+
1933
1922
  self.logger = logger
1923
+ self.dashboard = dashboard
1934
1924
  self.dashboard_viz = dashboard_viz
1935
1925
 
1936
- self._jax_compile_rddl()
1937
- self._jax_compile_optimizer()
1926
+ self._jax_compile_graph()
1938
1927
 
1939
1928
  @staticmethod
1940
1929
  def summarize_system() -> str:
1941
1930
  '''Returns a string containing information about the system, Python version
1942
1931
  and jax-related packages that are relevant to the current planner.
1943
- '''
1944
- try:
1945
- jaxlib_version = jax._src.lib.version_str
1946
- except Exception as _:
1947
- jaxlib_version = 'N/A'
1948
- try:
1949
- devices_short = ', '.join(
1950
- map(str, jax._src.xla_bridge.devices())).replace('\n', '')
1951
- except Exception as _:
1952
- devices_short = 'N/A'
1953
- LOGO = \
1954
- r"""
1955
- __ ______ __ __ ______ __ ______ __ __
1956
- /\ \ /\ __ \ /\_\_\_\ /\ == \/\ \ /\ __ \ /\ "-.\ \
1957
- _\_\ \\ \ __ \\/_/\_\/_\ \ _-/\ \ \____\ \ __ \\ \ \-. \
1958
- /\_____\\ \_\ \_\ /\_\/\_\\ \_\ \ \_____\\ \_\ \_\\ \_\\"\_\
1959
- \/_____/ \/_/\/_/ \/_/\/_/ \/_/ \/_____/ \/_/\/_/ \/_/ \/_/
1960
- """
1961
-
1962
- return (f'\n'
1963
- f'{LOGO}\n'
1964
- f'Version {__version__}\n'
1965
- f'Python {sys.version}\n'
1966
- f'jax {jax.version.__version__}, jaxlib {jaxlib_version}, '
1967
- f'optax {optax.__version__}, haiku {hk.__version__}, '
1968
- f'numpy {np.__version__}\n'
1969
- f'devices: {devices_short}\n')
1932
+ '''
1933
+ devices = jax.devices()
1934
+ default_device = devices[0] if devices else 'n/a'
1935
+ return termcolor.colored(
1936
+ '\n'
1937
+ f'Starting JaxPlan v{__version__} '
1938
+ f'on device {default_device.platform}{default_device.id}\n', attrs=['bold']
1939
+ )
1970
1940
 
1971
1941
  def summarize_relaxations(self) -> str:
1972
1942
  '''Returns a summary table containing all non-differentiable operators
1973
1943
  and their relaxations.
1974
1944
  '''
1975
1945
  result = ''
1976
- if self.compiled.model_params:
1977
- result += ('Some RDDL operations are non-differentiable '
1946
+ overriden_ops_info = self.compiled.overriden_ops_info()
1947
+ if overriden_ops_info:
1948
+ result += ('[INFO] Some RDDL operations are non-differentiable '
1978
1949
  'and will be approximated as follows:' + '\n')
1979
- exprs_by_rddl_op, values_by_rddl_op = {}, {}
1980
- for info in self.compiled.model_parameter_info().values():
1981
- rddl_op = info['rddl_op']
1982
- exprs_by_rddl_op.setdefault(rddl_op, []).append(info['id'])
1983
- values_by_rddl_op.setdefault(rddl_op, []).append(info['init_value'])
1984
- for rddl_op in sorted(exprs_by_rddl_op.keys()):
1985
- result += (f' {rddl_op}:\n'
1986
- f' addresses ={exprs_by_rddl_op[rddl_op]}\n'
1987
- f' init_values={values_by_rddl_op[rddl_op]}\n')
1950
+ for (class_, op_to_ids_dict) in overriden_ops_info.items():
1951
+ result += f' {class_}:\n'
1952
+ for (op, ids) in op_to_ids_dict.items():
1953
+ result += (
1954
+ f' {op} ' +
1955
+ termcolor.colored(f'[{len(ids)} occurrences]\n', 'dark_grey')
1956
+ )
1988
1957
  return result
1989
1958
 
1990
1959
  def summarize_hyperparameters(self) -> str:
1991
1960
  '''Returns a string summarizing the hyper-parameters of the current planner
1992
1961
  instance.
1993
1962
  '''
1994
- result = (f'objective hyper-parameters:\n'
1963
+ result = (f'[INFO] objective hyper-parameters:\n'
1995
1964
  f' utility_fn ={self.utility.__name__}\n'
1996
1965
  f' utility args ={self.utility_kwargs}\n'
1997
1966
  f' use_symlog ={self.use_symlog_reward}\n'
1998
1967
  f' lookahead ={self.horizon}\n'
1999
1968
  f' user_action_bounds={self._action_bounds}\n'
2000
- f' fuzzy logic type ={type(self.logic).__name__}\n'
2001
- f' non_fluents exact ={self.compile_non_fluent_exact}\n'
2002
- f' cpfs_no_gradient ={self.cpfs_without_grad}\n'
2003
- f'optimizer hyper-parameters:\n'
2004
- f' use_64_bit ={self.use64bit}\n'
1969
+ f'[INFO] optimizer hyper-parameters:\n'
2005
1970
  f' optimizer ={self.optimizer_name}\n'
2006
1971
  f' optimizer args ={self.optimizer_kwargs}\n'
2007
1972
  f' clip_gradient ={self.clip_grad}\n'
2008
1973
  f' line_search_kwargs={self.line_search_kwargs}\n'
2009
1974
  f' noise_kwargs ={self.noise_kwargs}\n'
1975
+ f' ema_decay ={self.ema_decay}\n'
2010
1976
  f' batch_size_train ={self.batch_size_train}\n'
2011
1977
  f' batch_size_test ={self.batch_size_test}\n'
2012
1978
  f' parallel_updates ={self.parallel_updates}\n'
@@ -2014,89 +1980,78 @@ r"""
2014
1980
  result += str(self.plan)
2015
1981
  if self.use_pgpe:
2016
1982
  result += str(self.pgpe)
2017
- result += str(self.logic)
1983
+ result += 'test compiler:\n'
1984
+ for k, v in self.test_compiled.get_kwargs().items():
1985
+ result += f' {k}={v}\n'
1986
+ result += 'train compiler:\n'
1987
+ for k, v in self.compiled.get_kwargs().items():
1988
+ result += f' {k}={v}\n'
2018
1989
  return result
2019
1990
 
2020
1991
  # ===========================================================================
2021
- # COMPILATION SUBROUTINES
1992
+ # COMPILE RDDL
2022
1993
  # ===========================================================================
2023
1994
 
2024
1995
  def _jax_compile_rddl(self):
2025
- rddl = self.rddl
2026
-
2027
- # Jax compilation of the differentiable RDDL for training
2028
- self.compiled = JaxRDDLCompilerWithGrad(
2029
- rddl=rddl,
2030
- logic=self.logic,
1996
+ self.compiled = self.compiler_type(
1997
+ rddl=self.rddl,
2031
1998
  logger=self.logger,
2032
- use64bit=self.use64bit,
2033
- cpfs_without_grad=self.cpfs_without_grad,
2034
- compile_non_fluent_exact=self.compile_non_fluent_exact,
2035
- print_warnings=self.print_warnings,
2036
- python_functions=self.python_functions
1999
+ python_functions=self.python_functions,
2000
+ **self.compiler_kwargs
2037
2001
  )
2002
+ self.print_warnings = self.compiled.print_warnings
2038
2003
  self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')
2039
-
2040
- # Jax compilation of the exact RDDL for testing
2004
+
2041
2005
  self.test_compiled = JaxRDDLCompiler(
2042
- rddl=rddl,
2006
+ rddl=self.rddl,
2007
+ allow_synchronous_state=True,
2043
2008
  logger=self.logger,
2044
- use64bit=self.use64bit,
2009
+ use64bit=self.compiled.use64bit,
2045
2010
  python_functions=self.python_functions
2046
2011
  )
2047
2012
  self.test_compiled.compile(log_jax_expr=True, heading='EXACT MODEL')
2048
-
2049
- def _jax_compile_optimizer(self):
2050
-
2051
- # preprocessor
2013
+
2014
+ def _jax_compile_policy(self):
2052
2015
  if self.preprocessor is not None:
2053
2016
  self.preprocessor.compile(self.compiled)
2054
-
2055
- # policy
2056
2017
  self.plan.compile(self.compiled,
2057
2018
  _bounds=self._action_bounds,
2058
2019
  horizon=self.horizon,
2059
2020
  preprocessor=self.preprocessor)
2060
2021
  self.train_policy = jax.jit(self.plan.train_policy)
2061
2022
  self.test_policy = jax.jit(self.plan.test_policy)
2062
-
2063
- # roll-outs
2064
- train_rollouts = self.compiled.compile_rollouts(
2023
+
2024
+ def _jax_compile_rollouts(self):
2025
+ self.train_rollouts = self.compiled.compile_rollouts(
2065
2026
  policy=self.plan.train_policy,
2066
2027
  n_steps=self.horizon,
2067
2028
  n_batch=self.batch_size_train,
2068
- cache_path_info=self.preprocessor is not None
2029
+ cache_path_info=self.preprocessor is not None or self.dashboard is not None
2069
2030
  )
2070
- self.train_rollouts = train_rollouts
2071
-
2072
2031
  test_rollouts = self.test_compiled.compile_rollouts(
2073
2032
  policy=self.plan.test_policy,
2074
2033
  n_steps=self.horizon,
2075
2034
  n_batch=self.batch_size_test,
2076
- cache_path_info=False
2035
+ cache_path_info=self.dashboard is not None
2077
2036
  )
2078
2037
  self.test_rollouts = jax.jit(test_rollouts)
2079
-
2080
- # initialization
2081
- self.initialize, self.init_optimizer = self._jax_init()
2082
-
2083
- # losses
2084
- train_loss = self._jax_loss(train_rollouts, use_symlog=self.use_symlog_reward)
2085
- test_loss = self._jax_loss(test_rollouts, use_symlog=False)
2086
- if self.parallel_updates is None:
2087
- self.test_loss = jax.jit(test_loss)
2088
- else:
2089
- self.test_loss = jax.jit(jax.vmap(test_loss, in_axes=(None, 0, None, None, 0)))
2090
-
2091
- # optimization
2038
+
2039
+ def _jax_compile_train_update(self):
2040
+ self.initialize, self.init_optimizer = self._jax_init_optimizer()
2041
+ train_loss = self._jax_loss(self.train_rollouts, use_symlog=self.use_symlog_reward)
2042
+ self.single_train_loss = train_loss
2092
2043
  self.update = self._jax_update(train_loss)
2093
- self.pytree_at = jax.jit(
2094
- lambda tree, i: jax.tree_util.tree_map(lambda x: x[i], tree))
2044
+
2045
+ def _jax_compile_test_loss(self):
2046
+ test_loss = self._jax_loss(self.test_rollouts, use_symlog=False)
2047
+ self.single_test_loss = test_loss
2048
+ self.test_loss = jax.jit(jax.vmap(
2049
+ test_loss, in_axes=(None, 0, None, None, None, 0)))
2095
2050
 
2096
- # pgpe option
2051
+ def _jax_compile_pgpe(self):
2097
2052
  if self.use_pgpe:
2098
2053
  self.pgpe.compile(
2099
- loss_fn=test_loss,
2054
+ loss_fn=self.single_test_loss,
2100
2055
  projection=self.plan.projection,
2101
2056
  real_dtype=self.test_compiled.REAL,
2102
2057
  print_warnings=self.print_warnings,
@@ -2106,6 +2061,14 @@ r"""
2106
2061
  else:
2107
2062
  self.merge_pgpe = None
2108
2063
 
2064
+ def _jax_compile_graph(self):
2065
+ self._jax_compile_rddl()
2066
+ self._jax_compile_policy()
2067
+ self._jax_compile_rollouts()
2068
+ self._jax_compile_train_update()
2069
+ self._jax_compile_test_loss()
2070
+ self._jax_compile_pgpe()
2071
+
2109
2072
  def _jax_return(self, use_symlog):
2110
2073
  gamma = self.rddl.discount
2111
2074
 
@@ -2117,9 +2080,8 @@ r"""
2117
2080
  rewards = rewards * discount[jnp.newaxis, ...]
2118
2081
  returns = jnp.sum(rewards, axis=1)
2119
2082
  if use_symlog:
2120
- returns = jnp.sign(returns) * jnp.log(1.0 + jnp.abs(returns))
2083
+ returns = jnp.sign(returns) * jnp.log1p(jnp.abs(returns))
2121
2084
  return returns
2122
-
2123
2085
  return _jax_wrapped_returns
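For intuition, the symlog transform above compresses large returns symmetrically while leaving small ones roughly unchanged (illustrative values):

returns = jnp.array([-100., 0.1, 100.])
jnp.sign(returns) * jnp.log1p(jnp.abs(returns))   # approximately [-4.62, 0.095, 4.62]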
2124
2086
 
2125
2087
  def _jax_loss(self, rollouts, use_symlog=False):
@@ -2128,48 +2090,44 @@ r"""
2128
2090
  _jax_wrapped_returns = self._jax_return(use_symlog)
2129
2091
 
2130
2092
  # the loss is the average cumulative reward across all roll-outs
2131
- def _jax_wrapped_plan_loss(key, policy_params, policy_hyperparams,
2132
- subs, model_params):
2093
+ # but applies a utility function, if requested, over the batch of return observations:
2094
+ # by default, the utility function is the mean
2095
+ def _jax_wrapped_plan_loss(key, policy_params, policy_hyperparams, fls, nfls,
2096
+ model_params):
2133
2097
  log, model_params = rollouts(
2134
- key, policy_params, policy_hyperparams, subs, model_params)
2098
+ key, policy_params, policy_hyperparams, fls, nfls, model_params)
2135
2099
  rewards = log['reward']
2136
2100
  returns = _jax_wrapped_returns(rewards)
2137
2101
  utility = utility_fn(returns, **utility_kwargs)
2138
2102
  loss = -utility
2139
2103
  aux = (log, model_params)
2140
2104
  return loss, aux
2141
-
2142
2105
  return _jax_wrapped_plan_loss
2143
2106
 
2144
- def _jax_init(self):
2107
+ def _jax_init_optimizer(self):
2145
2108
  init = self.plan.initializer
2146
2109
  optimizer = self.optimizer
2147
2110
  num_parallel = self.parallel_updates
2148
2111
 
2149
2112
  # initialize both the policy and its optimizer
2150
- def _jax_wrapped_init_policy(key, policy_hyperparams, subs):
2151
- policy_params = init(key, policy_hyperparams, subs)
2113
+ def _jax_wrapped_init_policy(key, policy_hyperparams, fls):
2114
+ policy_params = init(key, policy_hyperparams, fls)
2152
2115
  opt_state = optimizer.init(policy_params)
2153
2116
  return policy_params, opt_state, {}
2154
2117
 
2155
2118
  # initialize just the optimizer from the policy
2156
2119
  def _jax_wrapped_init_opt(policy_params):
2157
- if num_parallel is None:
2158
- opt_state = optimizer.init(policy_params)
2159
- else:
2160
- opt_state = jax.vmap(optimizer.init, in_axes=0)(policy_params)
2120
+ opt_state = jax.vmap(optimizer.init, in_axes=0)(policy_params)
2161
2121
  return opt_state, {}
2162
2122
 
2163
- if num_parallel is None:
2164
- return jax.jit(_jax_wrapped_init_policy), jax.jit(_jax_wrapped_init_opt)
2165
-
2166
- # for parallel policy update
2167
- def _jax_wrapped_init_policies(key, policy_hyperparams, subs):
2168
- keys = jnp.asarray(random.split(key, num=num_parallel))
2169
- return jax.vmap(_jax_wrapped_init_policy, in_axes=(0, None, None))(
2170
- keys, policy_hyperparams, subs)
2123
+ # initialize multiple policies to be optimized in parallel
2124
+ def _jax_wrapped_batched_init_policy(key, policy_hyperparams, fls):
2125
+ keys = random.split(key, num=num_parallel)
2126
+ return jax.vmap(
2127
+ _jax_wrapped_init_policy, in_axes=(0, None, None)
2128
+ )(keys, policy_hyperparams, fls)
2171
2129
 
2172
- return jax.jit(_jax_wrapped_init_policies), jax.jit(_jax_wrapped_init_opt)
2130
+ return jax.jit(_jax_wrapped_batched_init_policy), jax.jit(_jax_wrapped_init_opt)
2173
2131
 
2174
2132
  def _jax_update(self, loss):
2175
2133
  optimizer = self.optimizer
@@ -2185,114 +2143,121 @@ r"""
2185
2143
 
2186
2144
  # calculate the plan gradient w.r.t. return loss and update optimizer
2187
2145
  # also perform a projection step to satisfy constraints on actions
2188
- def _jax_wrapped_loss_swapped(policy_params, key, policy_hyperparams,
2189
- subs, model_params):
2190
- return loss(key, policy_params, policy_hyperparams, subs, model_params)[0]
2146
+ def _jax_wrapped_loss_swapped(policy_params, key, policy_hyperparams, fls, nfls,
2147
+ model_params):
2148
+ return loss(key, policy_params, policy_hyperparams, fls, nfls, model_params)[0]
2191
2149
 
2192
- def _jax_wrapped_plan_update(key, policy_params, policy_hyperparams,
2193
- subs, model_params, opt_state, opt_aux):
2150
+ def _jax_wrapped_plan_update(key, policy_params, policy_hyperparams, fls, nfls,
2151
+ model_params, opt_state, opt_aux):
2152
+
2153
+ # calculate the gradient of the loss with respect to the policy
2194
2154
  grad_fn = jax.value_and_grad(loss, argnums=1, has_aux=True)
2195
2155
  (loss_val, (log, model_params)), grad = grad_fn(
2196
- key, policy_params, policy_hyperparams, subs, model_params)
2156
+ key, policy_params, policy_hyperparams, fls, nfls, model_params)
2157
+
2158
+ # require a slightly different update if line search is used
2197
2159
  if use_ls:
2198
2160
  updates, opt_state = optimizer.update(
2199
2161
  grad, opt_state, params=policy_params,
2200
2162
  value=loss_val, grad=grad, value_fn=_jax_wrapped_loss_swapped,
2201
- key=key, policy_hyperparams=policy_hyperparams, subs=subs,
2202
- model_params=model_params)
2163
+ key=key, policy_hyperparams=policy_hyperparams, fls=fls, nfls=nfls,
2164
+ model_params=model_params
2165
+ )
2203
2166
  else:
2204
- updates, opt_state = optimizer.update(
2205
- grad, opt_state, params=policy_params)
2167
+ updates, opt_state = optimizer.update(grad, opt_state, params=policy_params)
2168
+
2169
+ # apply optimizer and optional policy projection
2206
2170
  policy_params = optax.apply_updates(policy_params, updates)
2207
2171
  policy_params, converged = projection(policy_params, policy_hyperparams)
2172
+
2208
2173
  log['grad'] = grad
2209
2174
  log['updates'] = updates
2210
2175
  zero_grads = _jax_wrapped_zero_gradients(grad)
2211
- return policy_params, converged, opt_state, opt_aux, \
2212
- loss_val, log, model_params, zero_grads
2176
+ return (policy_params, converged, opt_state, opt_aux,
2177
+ loss_val, log, model_params, zero_grads)
2213
2178
 
2214
- if num_parallel is None:
2215
- return jax.jit(_jax_wrapped_plan_update)
2216
-
2217
- # for parallel policy update
2218
- def _jax_wrapped_plan_updates(key, policy_params, policy_hyperparams,
2219
- subs, model_params, opt_state, opt_aux):
2220
- keys = jnp.asarray(random.split(key, num=num_parallel))
2179
+ # for parallel policy update, just do each policy update in parallel
2180
+ def _jax_wrapped_batched_plan_update(key, policy_params, policy_hyperparams,
2181
+ fls, nfls, model_params, opt_state, opt_aux):
2182
+ keys = random.split(key, num=num_parallel)
2221
2183
  return jax.vmap(
2222
- _jax_wrapped_plan_update, in_axes=(0, 0, None, None, 0, 0, 0)
2223
- )(keys, policy_params, policy_hyperparams, subs, model_params,
2184
+ _jax_wrapped_plan_update, in_axes=(0, 0, None, None, None, 0, 0, 0)
2185
+ )(keys, policy_params, policy_hyperparams, fls, nfls, model_params,
2224
2186
  opt_state, opt_aux)
2225
-
2226
- return jax.jit(_jax_wrapped_plan_updates)
2187
+ return jax.jit(_jax_wrapped_batched_plan_update)
2227
2188
 
2228
2189
  def _jax_merge_pgpe_jaxplan(self):
2229
- if self.parallel_updates is None:
2230
- return None
2231
-
2232
- # for parallel policy update
2190
+
2233
2191
  # currently implements a hard replacement where the jaxplan parameter
2234
2192
  # is replaced by the PGPE parameter if the latter is an improvement
2235
- def _jax_wrapped_pgpe_jaxplan_merge(pgpe_mask, pgpe_param, policy_params,
2193
+ def _jax_wrapped_batched_pgpe_merge(pgpe_mask, pgpe_param, policy_params,
2236
2194
  pgpe_loss, test_loss,
2237
2195
  pgpe_loss_smooth, test_loss_smooth,
2238
2196
  pgpe_converged, converged):
2239
- def select_fn(leaf1, leaf2):
2240
- expanded_mask = pgpe_mask[(...,) + (jnp.newaxis,) * (jnp.ndim(leaf1) - 1)]
2241
- return jnp.where(expanded_mask, leaf1, leaf2)
2242
- policy_params = jax.tree_util.tree_map(select_fn, pgpe_param, policy_params)
2197
+ mask_tree = jax.tree_util.tree_map(
2198
+ lambda leaf: pgpe_mask[(...,) + (jnp.newaxis,) * (jnp.ndim(leaf) - 1)],
2199
+ pgpe_param)
2200
+ policy_params = jax.tree_util.tree_map(
2201
+ jnp.where, mask_tree, pgpe_param, policy_params)
2243
2202
  test_loss = jnp.where(pgpe_mask, pgpe_loss, test_loss)
2244
2203
  test_loss_smooth = jnp.where(pgpe_mask, pgpe_loss_smooth, test_loss_smooth)
2245
- expanded_mask = pgpe_mask[(...,) + (jnp.newaxis,) * (jnp.ndim(converged) - 1)]
2246
- converged = jnp.where(expanded_mask, pgpe_converged, converged)
2204
+ converged = jnp.where(pgpe_mask, pgpe_converged, converged)
2247
2205
  return policy_params, test_loss, test_loss_smooth, converged
2206
+ return jax.jit(_jax_wrapped_batched_pgpe_merge)
2248
2207
 
2249
- return jax.jit(_jax_wrapped_pgpe_jaxplan_merge)
2250
-
2251
- def _batched_init_subs(self, subs):
2208
+ def _batched_init_subs(self, init_values):
2252
2209
  rddl = self.rddl
2253
2210
  n_train, n_test = self.batch_size_train, self.batch_size_test
2254
2211
 
2255
- # batched subs
2256
- init_train, init_test = {}, {}
2257
- for (name, value) in subs.items():
2212
+ init_train_fls, init_train_nfls, init_test_fls, init_test_nfls = {}, {}, {}, {}
2213
+ for (name, value) in init_values.items():
2214
+
2215
+ # get the initial fluent values and check validity
2258
2216
  init_value = self.test_compiled.init_values.get(name, None)
2259
2217
  if init_value is None:
2260
2218
  raise RDDLUndefinedVariableError(
2261
- f'Variable <{name}> in subs argument is not a '
2219
+ f'Variable <{name}> in init_values argument is not a '
2262
2220
  f'valid p-variable, must be one of '
2263
- f'{set(self.test_compiled.init_values.keys())}.')
2264
- value = np.reshape(value, np.shape(init_value))[np.newaxis, ...]
2221
+ f'{set(self.test_compiled.init_values.keys())}.'
2222
+ )
2223
+
2224
+ # for enum types need to convert the string values to integer indices
2225
+ if np.size(value) != np.size(init_value):
2226
+ value = init_value
2227
+ value = np.reshape(value, np.shape(init_value))
2265
2228
  if value.dtype.type is np.str_:
2266
2229
  value = rddl.object_string_to_index_array(rddl.variable_ranges[name], value)
2267
- train_value = np.repeat(value, repeats=n_train, axis=0)
2268
- train_value = np.asarray(train_value, dtype=self.compiled.REAL)
2269
- init_train[name] = train_value
2270
- init_test[name] = np.repeat(value, repeats=n_test, axis=0)
2230
+
2231
+ # train and test fluents have a batch dimension added, non-fluents do not
2232
+ # train fluents are also converted to float
2233
+ if name not in rddl.non_fluents:
2234
+ train_value = np.repeat(value[np.newaxis, ...], repeats=n_train, axis=0)
2235
+ init_train_fls[name] = np.asarray(train_value, dtype=self.compiled.REAL)
2236
+ init_test_fls[name] = np.repeat(value[np.newaxis, ...], repeats=n_test, axis=0)
2237
+ else:
2238
+ init_train_nfls[name] = np.asarray(value, dtype=self.compiled.REAL)
2239
+ init_test_nfls[name] = value
2271
2240
 
2272
- # safely cast test subs variable to required type in case the type is wrong
2241
+ # safely cast test variable to required type in case the type is wrong
2273
2242
  if name in rddl.variable_ranges:
2274
2243
  required_type = RDDLValueInitializer.NUMPY_TYPES.get(
2275
2244
  rddl.variable_ranges[name], RDDLValueInitializer.INT)
2245
+ init_test = init_test_nfls if name in rddl.non_fluents else init_test_fls
2276
2246
  if np.result_type(init_test[name]) != required_type:
2277
2247
  init_test[name] = np.asarray(init_test[name], dtype=required_type)
2278
2248
 
2279
2249
  # make sure next-state fluents are also set
2280
2250
  for (state, next_state) in rddl.next_state.items():
2281
- init_train[next_state] = init_train[state]
2282
- init_test[next_state] = init_test[state]
2283
- return init_train, init_test
2251
+ init_train_fls[next_state] = init_train_fls[state]
2252
+ init_test_fls[next_state] = init_test_fls[state]
2253
+ return (init_train_fls, init_train_nfls), (init_test_fls, init_test_nfls)
2284
2254
 
2285
2255
  def _broadcast_pytree(self, pytree):
2286
- if self.parallel_updates is None:
2287
- return pytree
2288
-
2289
- # for parallel policy update
2290
2256
  def make_batched(x):
2291
2257
  x = np.asarray(x)
2292
2258
  x = np.broadcast_to(
2293
2259
  x[np.newaxis, ...], shape=(self.parallel_updates,) + np.shape(x))
2294
2260
  return x
2295
-
2296
2261
  return jax.tree_util.tree_map(make_batched, pytree)
2297
2262
 
2298
2263
  def as_optimization_problem(
@@ -2324,7 +2289,7 @@ r"""
2324
2289
  '''
2325
2290
 
2326
2291
  # make sure parallel updates are disabled
2327
- if self.parallel_updates is not None:
2292
+ if self.parallel_updates > 1:
2328
2293
  raise ValueError('Cannot compile static optimization problem '
2329
2294
  'when parallel_updates is greater than 1.')
2330
2295
 
@@ -2333,42 +2298,45 @@ r"""
2333
2298
  key = random.PRNGKey(round(time.time() * 1000))
2334
2299
 
2335
2300
  # initialize the initial fluents, model parameters, policy hyper-params
2336
- subs = self.test_compiled.init_values
2337
- train_subs, _ = self._batched_init_subs(subs)
2338
- model_params = self.compiled.model_params
2301
+ (fls, nfls), _ = self._batched_init_subs(self.test_compiled.init_values)
2302
+ model_params = self.compiled.model_aux['params']
2339
2303
  if policy_hyperparams is None:
2340
2304
  if self.print_warnings:
2341
- message = termcolor.colored(
2342
- '[WARN] policy_hyperparams is not set, setting 1.0 for '
2343
- 'all action-fluents which could be suboptimal.', 'yellow')
2344
- print(message)
2345
- policy_hyperparams = {action: 1.0
2346
- for action in self.rddl.action_fluents}
2305
+ print(termcolor.colored(
2306
+ '[WARN] policy_hyperparams is not set: setting values to 1.0 for '
2307
+ 'all action-fluents, which could be suboptimal.', 'yellow'
2308
+ ))
2309
+ policy_hyperparams = {action: 1. for action in self.rddl.action_fluents}
2347
2310
 
2348
2311
  # initialize the policy parameters
2349
- params_guess, *_ = self.initialize(key, policy_hyperparams, train_subs)
2312
+ params_guess = self.initialize(key, policy_hyperparams, fls)[0]
2313
+ params_guess = pytree_at(params_guess, 0)
2314
+
2315
+ # get the params mapping to a 1D vector
2350
2316
  guess_1d, unravel_fn = jax.flatten_util.ravel_pytree(params_guess)
2351
2317
  guess_1d = np.asarray(guess_1d)
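ravel_pytree maps the parameter pytree to a single 1D vector and returns the inverse mapping used below; a small sketch with hypothetical parameters:

import jax.flatten_util
import jax.numpy as jnp

params = {'w': jnp.ones((2, 2)), 'b': jnp.zeros(2)}
flat, unravel = jax.flatten_util.ravel_pytree(params)
# flat.shape == (6,); unravel(flat) rebuilds the original dict of arrays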
2352
2318
 
2353
- # computes the training loss function and its 1D gradient
2354
- loss_fn = self._jax_loss(self.train_rollouts)
2355
-
2319
+ # computes the training loss function in a 1D vector
2356
2320
  @jax.jit
2357
2321
  def _loss_with_key(key, params_1d, model_params):
2358
2322
  policy_params = unravel_fn(params_1d)
2359
- loss_val, (_, model_params) = loss_fn(
2360
- key, policy_params, policy_hyperparams, train_subs, model_params)
2323
+ loss_val, (_, model_params) = self.single_train_loss(
2324
+ key, policy_params, policy_hyperparams, fls, nfls, model_params)
2361
2325
  return loss_val, model_params
2362
2326
 
2327
+ # computes the training loss gradient function in a 1D vector
2328
+ grad_fn = jax.grad(self.single_train_loss, argnums=1, has_aux=True)
2329
+
2363
2330
  @jax.jit
2364
2331
  def _grad_with_key(key, params_1d, model_params):
2365
2332
  policy_params = unravel_fn(params_1d)
2366
- grad_fn = jax.grad(loss_fn, argnums=1, has_aux=True)
2367
2333
  grad_val, (_, model_params) = grad_fn(
2368
- key, policy_params, policy_hyperparams, train_subs, model_params)
2334
+ key, policy_params, policy_hyperparams, fls, nfls, model_params)
2369
2335
  grad_val = jax.flatten_util.ravel_pytree(grad_val)[0]
2370
2336
  return grad_val, model_params
2371
2337
 
2338
+ # keep nonlocal references to the PRNG key and model parameters, passing them when
2339
+ # required by JAX and updating them upon return
2372
2340
  def _loss_function(params_1d):
2373
2341
  nonlocal key
2374
2342
  nonlocal model_params
@@ -2402,9 +2370,7 @@ r"""
2402
2370
 
2403
2371
  :param key: JAX PRNG key (derived from clock if not provided)
2404
2372
  :param epochs: the maximum number of steps of gradient descent
2405
- :param train_seconds: total time allocated for gradient descent
2406
- :param dashboard: dashboard to display training results
2407
- :param dashboard_id: experiment id for the dashboard
2373
+ :param train_seconds: total time allocated for gradient descent
2408
2374
  :param model_params: optional model-parameters to override default
2409
2375
  :param policy_hyperparams: hyper-parameters for the policy/plan, such as
2410
2376
  weights for sigmoid wrapping boolean actions
@@ -2414,10 +2380,9 @@ r"""
2414
2380
  specified in this instance
2415
2381
  :param print_summary: whether to print planner header and diagnosis
2416
2382
  :param print_progress: whether to print the progress bar during training
2417
- :param print_hyperparams: whether to print list of hyper-parameter settings
2383
+ :param print_hyperparams: whether to print list of hyper-parameter settings
2384
+ :param dashboard_id: experiment id for the dashboard
2418
2385
  :param stopping_rule: stopping criterion
2419
- :param restart_epochs: restart the optimizer from a random policy configuration
2420
- if there is no progress for this many consecutive iterations
2421
2386
  :param test_rolling_window: the test return is averaged on a rolling
2422
2387
  window of the past test_rolling_window returns when updating the best
2423
2388
  parameters found so far
@@ -2442,8 +2407,6 @@ r"""
2442
2407
  def optimize_generator(self, key: Optional[random.PRNGKey]=None,
2443
2408
  epochs: int=999999,
2444
2409
  train_seconds: float=120.,
2445
- dashboard: Optional[Any]=None,
2446
- dashboard_id: Optional[str]=None,
2447
2410
  model_params: Optional[Dict[str, Any]]=None,
2448
2411
  policy_hyperparams: Optional[Dict[str, Any]]=None,
2449
2412
  subs: Optional[Dict[str, Any]]=None,
@@ -2451,8 +2414,8 @@ r"""
2451
2414
  print_summary: bool=True,
2452
2415
  print_progress: bool=True,
2453
2416
  print_hyperparams: bool=False,
2417
+ dashboard_id: Optional[str]=None,
2454
2418
  stopping_rule: Optional[JaxPlannerStoppingRule]=None,
2455
- restart_epochs: int=999999,
2456
2419
  test_rolling_window: int=10,
2457
2420
  tqdm_position: Optional[int]=None) -> Generator[Dict[str, Any], None, None]:
2458
2421
  '''Returns a generator for computing an optimal policy or plan.
@@ -2461,9 +2424,7 @@ r"""
2461
2424
 
2462
2425
  :param key: JAX PRNG key (derived from clock if not provided)
2463
2426
  :param epochs: the maximum number of steps of gradient descent
2464
- :param train_seconds: total time allocated for gradient descent
2465
- :param dashboard: dashboard to display training results
2466
- :param dashboard_id: experiment id for the dashboard
2427
+ :param train_seconds: total time allocated for gradient descent
2467
2428
  :param model_params: optional model-parameters to override default
2468
2429
  :param policy_hyperparams: hyper-parameters for the policy/plan, such as
2469
2430
  weights for sigmoid wrapping boolean actions
@@ -2474,14 +2435,15 @@ r"""
2474
2435
  :param print_summary: whether to print planner header and diagnosis
2475
2436
  :param print_progress: whether to print the progress bar during training
2476
2437
  :param print_hyperparams: whether to print list of hyper-parameter settings
2438
+ :param dashboard_id: experiment id for the dashboard
2477
2439
  :param stopping_rule: stopping criterion
2478
- :param restart_epochs: restart the optimizer from a random policy configuration
2479
- if there is no progress for this many consecutive iterations
2480
2440
  :param test_rolling_window: the test return is averaged on a rolling
2481
2441
  window of the past test_rolling_window returns when updating the best
2482
2442
  parameters found so far
2483
2443
  :param tqdm_position: position of tqdm progress bar (for multiprocessing)
2484
2444
  '''
2445
+
2446
+ # start measuring execution time here, including time spent outside optimize loop
2485
2447
  start_time = time.time()
2486
2448
  elapsed_outside_loop = 0
2487
2449
 
@@ -2489,39 +2451,27 @@ r"""
2489
2451
  # INITIALIZATION OF HYPER-PARAMETERS
2490
2452
  # ======================================================================
2491
2453
 
2492
- # cannot run dashboard with parallel updates
2493
- if dashboard is not None and self.parallel_updates is not None:
2494
- if self.print_warnings:
2495
- message = termcolor.colored(
2496
- '[WARN] Dashboard is unavailable if parallel_updates is not None: '
2497
- 'setting dashboard to None.', 'yellow')
2498
- print(message)
2499
- dashboard = None
2500
-
2501
2454
  # if PRNG key is not provided
2502
2455
  if key is None:
2503
2456
  key = random.PRNGKey(round(time.time() * 1000))
2457
+ if self.print_warnings:
2458
+ print(termcolor.colored(
2459
+ '[WARN] PRNG key is not set: setting from clock.', 'yellow'
2460
+ ))
2504
2461
  dash_key = key[1].item()
2505
2462
 
2506
2463
  # if policy_hyperparams is not provided
2507
2464
  if policy_hyperparams is None:
2508
2465
  if self.print_warnings:
2509
- message = termcolor.colored(
2510
- '[WARN] policy_hyperparams is not set, setting 1.0 for '
2511
- 'all action-fluents which could be suboptimal.', 'yellow')
2512
- print(message)
2513
- policy_hyperparams = {action: 1.0
2514
- for action in self.rddl.action_fluents}
2466
+ print(termcolor.colored(
2467
+ '[WARN] policy_hyperparams is not set: setting values to 1.0 for '
2468
+ 'all action-fluents, which could be suboptimal.', 'yellow'
2469
+ ))
2470
+ policy_hyperparams = {action: 1. for action in self.rddl.action_fluents}
2515
2471
 
2516
2472
  # if policy_hyperparams is a scalar
2517
2473
  elif isinstance(policy_hyperparams, (int, float, np.number)):
2518
- if self.print_warnings:
2519
- message = termcolor.colored(
2520
- f'[INFO] policy_hyperparams is {policy_hyperparams}, '
2521
- f'setting this value for all action-fluents.', 'green')
2522
- print(message)
2523
- hyperparam_value = float(policy_hyperparams)
2524
- policy_hyperparams = {action: hyperparam_value
2474
+ policy_hyperparams = {action: float(policy_hyperparams)
2525
2475
  for action in self.rddl.action_fluents}
2526
2476
 
2527
2477
  # fill in missing entries
@@ -2529,12 +2479,12 @@ r"""
2529
2479
  for action in self.rddl.action_fluents:
2530
2480
  if action not in policy_hyperparams:
2531
2481
  if self.print_warnings:
2532
- message = termcolor.colored(
2533
- f'[WARN] policy_hyperparams[{action}] is not set, '
2534
- f'setting 1.0 for missing action-fluents '
2535
- f'which could be suboptimal.', 'yellow')
2536
- print(message)
2537
- policy_hyperparams[action] = 1.0
2482
+ print(termcolor.colored(
2483
+ f'[WARN] policy_hyperparams[{action}] is not set: '
2484
+ f'setting values to 1.0 for missing action-fluents, '
2485
+ f'which could be suboptimal.', 'yellow'
2486
+ ))
2487
+ policy_hyperparams[action] = 1.
2538
2488
 
2539
2489
  # initialize preprocessor
2540
2490
  preproc_key = None
@@ -2548,21 +2498,21 @@ r"""
2548
2498
  print(self.summarize_relaxations())
2549
2499
  if print_hyperparams:
2550
2500
  print(self.summarize_hyperparameters())
2551
- print(f'optimize() call hyper-parameters:\n'
2552
- f' PRNG key ={key}\n'
2553
- f' max_iterations ={epochs}\n'
2554
- f' max_seconds ={train_seconds}\n'
2555
- f' model_params ={model_params}\n'
2556
- f' policy_hyper_params={policy_hyperparams}\n'
2557
- f' override_subs_dict ={subs is not None}\n'
2558
- f' provide_param_guess={guess is not None}\n'
2559
- f' test_rolling_window={test_rolling_window}\n'
2560
- f' dashboard ={dashboard is not None}\n'
2561
- f' dashboard_id ={dashboard_id}\n'
2562
- f' print_summary ={print_summary}\n'
2563
- f' print_progress ={print_progress}\n'
2564
- f' stopping_rule ={stopping_rule}\n'
2565
- f' restart_epochs ={restart_epochs}\n')
2501
+ print(
2502
+ f'[INFO] optimize call hyper-parameters:\n'
2503
+ f' PRNG key ={key}\n'
2504
+ f' max_iterations ={epochs}\n'
2505
+ f' max_seconds ={train_seconds}\n'
2506
+ f' model_params ={model_params}\n'
2507
+ f' policy_hyper_params={policy_hyperparams}\n'
2508
+ f' override_subs_dict ={subs is not None}\n'
2509
+ f' provide_param_guess={guess is not None}\n'
2510
+ f' test_rolling_window={test_rolling_window}\n'
2511
+ f' print_summary ={print_summary}\n'
2512
+ f' print_progress ={print_progress}\n'
2513
+ f' dashboard_id ={dashboard_id}\n'
2514
+ f' stopping_rule ={stopping_rule}\n'
2515
+ )
2566
2516
 
2567
2517
  # ======================================================================
2568
2518
  # INITIALIZATION OF STATE AND POLICY
@@ -2580,23 +2530,23 @@ r"""
2580
2530
  subs[var] = value
2581
2531
  added_pvars_to_subs.append(var)
2582
2532
  if self.print_warnings and added_pvars_to_subs:
2583
- message = termcolor.colored(
2584
- f'[INFO] p-variables {added_pvars_to_subs} is not in '
2585
- f'provided subs, using their initial values.', 'green')
2586
- print(message)
2533
+ print(termcolor.colored(
2534
+ f'[INFO] p-variable(s) {added_pvars_to_subs} are not in '
2535
+ f'provided subs: using their initial values.', 'dark_grey'
2536
+ ))
2587
2537
  train_subs, test_subs = self._batched_init_subs(subs)
2588
2538
 
2589
2539
  # initialize model parameters
2590
2540
  if model_params is None:
2591
- model_params = self.compiled.model_params
2541
+ model_params = self.compiled.model_aux['params']
2592
2542
  model_params = self._broadcast_pytree(model_params)
2593
- model_params_test = self._broadcast_pytree(self.test_compiled.model_params)
2543
+ model_params_test = self._broadcast_pytree(self.test_compiled.model_aux['params'])
2594
2544
 
2595
2545
  # initialize policy parameters
2596
2546
  if guess is None:
2597
2547
  key, subkey = random.split(key)
2598
2548
  policy_params, opt_state, opt_aux = self.initialize(
2599
- subkey, policy_hyperparams, train_subs)
2549
+ subkey, policy_hyperparams, train_subs[0])
2600
2550
  else:
2601
2551
  policy_params = self._broadcast_pytree(guess)
2602
2552
  opt_state, opt_aux = self.init_optimizer(policy_params)
@@ -2606,8 +2556,7 @@ r"""
2606
2556
  pgpe_params, pgpe_opt_state, r_max = self.pgpe.initialize(key, policy_params)
2607
2557
  rolling_pgpe_loss = RollingMean(test_rolling_window)
2608
2558
  else:
2609
- pgpe_params, pgpe_opt_state, r_max = None, None, None
2610
- rolling_pgpe_loss = None
2559
+ pgpe_params = pgpe_opt_state = r_max = rolling_pgpe_loss = None
2611
2560
  total_pgpe_it = 0
2612
2561
 
2613
2562
  # ======================================================================
@@ -2615,13 +2564,10 @@ r"""
2615
2564
  # ======================================================================
2616
2565
 
2617
2566
  # initialize running statistics
2618
- if self.parallel_updates is None:
2619
- best_params = policy_params
2620
- else:
2621
- best_params = self.pytree_at(policy_params, 0)
2567
+ best_params = pytree_at(policy_params, 0)
2622
2568
  best_loss, pbest_loss, best_grad = np.inf, np.inf, None
2569
+ best_index = 0
2623
2570
  last_iter_improve = 0
2624
- no_progress_count = 0
2625
2571
  rolling_test_loss = RollingMean(test_rolling_window)
2626
2572
  status = JaxPlannerStatus.NORMAL
2627
2573
  progress_percent = 0
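With the parallel_updates branch removed, the running statistics always treat the policy parameters as batched over parallel runs, and pytree_at extracts one run from that batch. Assuming it simply indexes every leaf along the leading axis, an equivalent sketch is:

    import jax
    import numpy as np

    def pytree_at(pytree, index):
        # take the index-th slice of every leaf along the leading (run) axis
        # (assumed to match the module-level helper used above)
        return jax.tree_util.tree_map(lambda leaf: leaf[index], pytree)

    batched = {'w': np.arange(12).reshape(4, 3), 'b': np.arange(4.0)}
    first_run = pytree_at(batched, 0)   # {'w': array([0, 1, 2]), 'b': 0.0}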
@@ -2630,11 +2576,15 @@ r"""
2630
2576
  if stopping_rule is not None:
2631
2577
  stopping_rule.reset()
2632
2578
 
2633
- # initialize dash board
2579
+ # initialize dashboard
2580
+ dashboard = self.dashboard
2634
2581
  if dashboard is not None:
2635
2582
  dashboard_id = dashboard.register_experiment(
2636
- dashboard_id, dashboard.get_planner_info(self),
2637
- key=dash_key, viz=self.dashboard_viz)
2583
+ dashboard_id,
2584
+ dashboard.get_planner_info(self),
2585
+ key=dash_key,
2586
+ viz=self.dashboard_viz
2587
+ )
2638
2588
 
2639
2589
  # progress bar
2640
2590
  if print_progress:
@@ -2646,8 +2596,8 @@ r"""
2646
2596
 
2647
2597
  # error handlers (to avoid spam messaging)
2648
2598
  policy_constraint_msg_shown = False
2649
- jax_train_msg_shown = False
2650
- jax_test_msg_shown = False
2599
+ jax_train_msg_shown = set()
2600
+ jax_test_msg_shown = set()
2651
2601
 
2652
2602
  # ======================================================================
2653
2603
  # MAIN TRAINING LOOP BEGINS
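Replacing the two boolean flags with sets changes the reporting policy: instead of silencing all compiler diagnostics after the first one, the loop now prints each distinct error code once for the training model and once for the testing model. The pattern in isolation:

    import numpy as np

    seen_error_codes = set()

    def report_new_errors(error_codes, get_messages):
        # print the message(s) for an error code only the first time it is seen
        for code in np.unique(error_codes):
            if code not in seen_error_codes:
                seen_error_codes.add(code)
                for message in get_messages(code):
                    print('[FAIL] Training model error: ' + message)

    report_new_errors([2, 2, 5], lambda c: [f'error code {c}'])   # prints codes 2 and 5
    report_new_errors([2, 5], lambda c: [f'error code {c}'])      # prints nothing new

Here get_messages stands in for JaxRDDLCompiler.get_error_messages; the helper name and the sample codes are illustrative.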
@@ -2656,7 +2606,7 @@ r"""
2656
2606
  for it in range(epochs):
2657
2607
 
2658
2608
  # ==================================================================
2659
- # NEXT GRADIENT DESCENT STEP
2609
+ # JAXPLAN GRADIENT DESCENT STEP
2660
2610
  # ==================================================================
2661
2611
 
2662
2612
  status = JaxPlannerStatus.NORMAL
@@ -2665,135 +2615,113 @@ r"""
2665
2615
  key, subkey = random.split(key)
2666
2616
  (policy_params, converged, opt_state, opt_aux, train_loss, train_log,
2667
2617
  model_params, zero_grads) = self.update(
2668
- subkey, policy_params, policy_hyperparams, train_subs, model_params,
2669
- opt_state, opt_aux)
2618
+ subkey, policy_params, policy_hyperparams, *train_subs, model_params,
2619
+ opt_state, opt_aux
2620
+ )
2670
2621
 
2671
2622
  # update the preprocessor
2672
2623
  if self.preprocessor is not None:
2673
2624
  policy_hyperparams[preproc_key] = self.preprocessor.update(
2674
- train_log['fluents'], policy_hyperparams[preproc_key])
2625
+ train_log['fluents'], policy_hyperparams[preproc_key]
2626
+ )
2675
2627
 
2676
2628
  # evaluate
2677
2629
  test_loss, (test_log, model_params_test) = self.test_loss(
2678
- subkey, policy_params, policy_hyperparams, test_subs, model_params_test)
2679
- if self.parallel_updates:
2680
- train_loss = np.asarray(train_loss)
2681
- test_loss = np.asarray(test_loss)
2630
+ subkey, policy_params, policy_hyperparams, *test_subs, model_params_test
2631
+ )
2632
+ train_loss = np.asarray(train_loss)
2633
+ test_loss = np.asarray(test_loss)
2682
2634
  test_loss_smooth = rolling_test_loss.update(test_loss)
2683
2635
 
2684
- # pgpe update of the plan
2636
+ # ==================================================================
2637
+ # PGPE GRADIENT DESCENT STEP
2638
+ # ==================================================================
2639
+
2685
2640
  pgpe_improve = False
2686
2641
  if self.use_pgpe:
2642
+
2643
+ # pgpe update of the plan
2687
2644
  key, subkey = random.split(key)
2688
- pgpe_params, r_max, pgpe_opt_state, pgpe_param, pgpe_converged = \
2689
- self.pgpe.update(subkey, pgpe_params, r_max, progress_percent,
2690
- policy_hyperparams, test_subs, model_params_test,
2691
- pgpe_opt_state)
2645
+ pgpe_params, r_max, pgpe_opt_state, pgpe_param, pgpe_converged = self.pgpe.update(
2646
+ subkey, pgpe_params, r_max, progress_percent,
2647
+ policy_hyperparams, *test_subs, model_params_test, pgpe_opt_state
2648
+ )
2692
2649
 
2693
2650
  # evaluate
2694
2651
  pgpe_loss, _ = self.test_loss(
2695
- subkey, pgpe_param, policy_hyperparams, test_subs, model_params_test)
2696
- if self.parallel_updates:
2697
- pgpe_loss = np.asarray(pgpe_loss)
2652
+ subkey, pgpe_param, policy_hyperparams, *test_subs, model_params_test)
2653
+ pgpe_loss = np.asarray(pgpe_loss)
2698
2654
  pgpe_loss_smooth = rolling_pgpe_loss.update(pgpe_loss)
2699
2655
  pgpe_return = -pgpe_loss_smooth
2700
2656
 
2701
2657
  # replace JaxPlan with PGPE if new minimum reached or train loss invalid
2702
- if self.parallel_updates is None:
2703
- if pgpe_loss_smooth < best_loss or not np.isfinite(train_loss):
2704
- policy_params = pgpe_param
2705
- test_loss, test_loss_smooth = pgpe_loss, pgpe_loss_smooth
2706
- converged = pgpe_converged
2707
- pgpe_improve = True
2708
- total_pgpe_it += 1
2709
- else:
2710
- pgpe_mask = (pgpe_loss_smooth < pbest_loss) | ~np.isfinite(train_loss)
2711
- if np.any(pgpe_mask):
2712
- policy_params, test_loss, test_loss_smooth, converged = \
2713
- self.merge_pgpe(pgpe_mask, pgpe_param, policy_params,
2714
- pgpe_loss, test_loss,
2715
- pgpe_loss_smooth, test_loss_smooth,
2716
- pgpe_converged, converged)
2717
- pgpe_improve = True
2718
- total_pgpe_it += 1
2658
+ pgpe_mask = (pgpe_loss_smooth < pbest_loss) | ~np.isfinite(train_loss)
2659
+ if np.any(pgpe_mask):
2660
+ policy_params, test_loss, test_loss_smooth, converged = self.merge_pgpe(
2661
+ pgpe_mask, pgpe_param, policy_params,
2662
+ pgpe_loss, test_loss, pgpe_loss_smooth, test_loss_smooth,
2663
+ pgpe_converged, converged
2664
+ )
2665
+ pgpe_improve = True
2666
+ total_pgpe_it += 1
2719
2667
  else:
2720
- pgpe_loss, pgpe_loss_smooth, pgpe_return = None, None, None
2668
+ pgpe_loss = pgpe_loss_smooth = pgpe_return = None
2721
2669
 
2722
- # evaluate test losses and record best parameters so far
2723
- if self.parallel_updates is None:
2724
- if test_loss_smooth < best_loss:
2725
- best_params, best_loss, best_grad = \
2726
- policy_params, test_loss_smooth, train_log['grad']
2727
- pbest_loss = best_loss
2728
- else:
2729
- best_index = np.argmin(test_loss_smooth)
2730
- if test_loss_smooth[best_index] < best_loss:
2731
- best_params = self.pytree_at(policy_params, best_index)
2732
- best_grad = self.pytree_at(train_log['grad'], best_index)
2733
- best_loss = test_loss_smooth[best_index]
2734
- pbest_loss = np.minimum(pbest_loss, test_loss_smooth)
2735
-
2736
2670
  # ==================================================================
2737
2671
  # STATUS CHECKS AND LOGGING
2738
2672
  # ==================================================================
2739
-
2673
+
2674
+ # evaluate test losses and record best parameters so far
2675
+ best_index = np.argmin(test_loss_smooth)
2676
+ if test_loss_smooth[best_index] < best_loss:
2677
+ best_params = pytree_at(policy_params, best_index)
2678
+ best_grad = pytree_at(train_log['grad'], best_index)
2679
+ best_loss = test_loss_smooth[best_index]
2680
+ last_iter_improve = it
2681
+ pbest_loss = np.minimum(pbest_loss, test_loss_smooth)
2682
+
2740
2683
  # no progress
2741
- no_progress_flag = (not pgpe_improve) and np.all(zero_grads)
2742
- if no_progress_flag:
2684
+ if (not pgpe_improve) and np.all(zero_grads):
2743
2685
  status = JaxPlannerStatus.NO_PROGRESS
2744
2686
 
2745
2687
  # constraint satisfaction problem
2746
2688
  if not np.all(converged):
2747
2689
  if progress_bar is not None and not policy_constraint_msg_shown:
2748
- message = termcolor.colored(
2749
- '[FAIL] Policy update failed to satisfy action constraints.',
2750
- 'red')
2751
- progress_bar.write(message)
2690
+ progress_bar.write(termcolor.colored(
2691
+ '[FAIL] Policy update violated action constraints.', 'red'
2692
+ ))
2752
2693
  policy_constraint_msg_shown = True
2753
2694
  status = JaxPlannerStatus.PRECONDITION_POSSIBLY_UNSATISFIED
2754
2695
 
2755
2696
  # numerical error
2697
+ invalid_loss = not np.any(np.isfinite(train_loss))
2756
2698
  if self.use_pgpe:
2757
- invalid_loss = not (np.any(np.isfinite(train_loss)) or
2758
- np.any(np.isfinite(pgpe_loss)))
2759
- else:
2760
- invalid_loss = not np.any(np.isfinite(train_loss))
2699
+ invalid_loss = invalid_loss and not np.any(np.isfinite(pgpe_loss))
2761
2700
  if invalid_loss:
2762
2701
  if progress_bar is not None:
2763
- message = termcolor.colored(
2764
- f'[FAIL] Planner aborted due to invalid train loss {train_loss}.',
2765
- 'red')
2766
- progress_bar.write(message)
2702
+ progress_bar.write(termcolor.colored(
2703
+ f'[FAIL] Planner aborted early with train loss {train_loss}.', 'red'
2704
+ ))
2767
2705
  status = JaxPlannerStatus.INVALID_GRADIENT
2768
2706
 
2769
2707
  # problem in the model compilation
2770
2708
  if progress_bar is not None:
2771
2709
 
2772
2710
  # train model
2773
- if not jax_train_msg_shown:
2774
- messages = set()
2775
- for error_code in np.unique(train_log['error']):
2776
- messages.update(JaxRDDLCompiler.get_error_messages(error_code))
2777
- if messages:
2778
- messages = '\n '.join(messages)
2779
- message = termcolor.colored(
2780
- f'[FAIL] Compiler encountered the following '
2781
- f'error(s) in the training model:\n {messages}', 'red')
2782
- progress_bar.write(message)
2783
- jax_train_msg_shown = True
2711
+ for error_code in np.unique(train_log['error']):
2712
+ if error_code not in jax_train_msg_shown:
2713
+ jax_train_msg_shown.add(error_code)
2714
+ for message in JaxRDDLCompiler.get_error_messages(error_code):
2715
+ progress_bar.write(termcolor.colored(
2716
+ '[FAIL] Training model error: ' + message, 'red'))
2784
2717
 
2785
2718
  # test model
2786
- if not jax_test_msg_shown:
2787
- messages = set()
2788
- for error_code in np.unique(test_log['error']):
2789
- messages.update(JaxRDDLCompiler.get_error_messages(error_code))
2790
- if messages:
2791
- messages = '\n '.join(messages)
2792
- message = termcolor.colored(
2793
- f'[FAIL] Compiler encountered the following '
2794
- f'error(s) in the testing model:\n {messages}', 'red')
2795
- progress_bar.write(message)
2796
- jax_test_msg_shown = True
2719
+ for error_code in np.unique(test_log['error']):
2720
+ if error_code not in jax_test_msg_shown:
2721
+ jax_test_msg_shown.add(error_code)
2722
+ for message in JaxRDDLCompiler.get_error_messages(error_code):
2723
+ progress_bar.write(termcolor.colored(
2724
+ '[FAIL] Testing model error: ' + message, 'red'))
2797
2725
 
2798
2726
  # reached computation budget
2799
2727
  elapsed = time.time() - start_time - elapsed_outside_loop
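In the rewritten PGPE branch every quantity is an array over parallel runs: pgpe_mask flags the runs where the smoothed PGPE loss beats that run's best loss so far (or where the gradient-based train loss is not finite), and self.merge_pgpe swaps the PGPE parameters into exactly those runs. merge_pgpe itself is not shown in this hunk; the general idea, sketched under that assumption, is a per-leaf masked select:

    import numpy as np
    import jax

    def merge_by_mask(mask, pgpe_params, plan_params):
        # where mask is True take the PGPE leaf, otherwise keep the plan leaf
        # (assumed behaviour of merge_pgpe, not the package's actual code)
        def merge_leaf(pgpe_leaf, plan_leaf):
            expanded = mask.reshape(mask.shape + (1,) * (np.ndim(plan_leaf) - 1))
            return np.where(expanded, pgpe_leaf, plan_leaf)
        return jax.tree_util.tree_map(merge_leaf, pgpe_params, plan_params)

    mask = np.array([True, False])                 # 2 parallel runs
    pgpe = {'w': np.ones((2, 3))}
    plan = {'w': np.zeros((2, 3))}
    merged = merge_by_mask(mask, pgpe, plan)       # run 0 from PGPE, run 1 unchanged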
@@ -2806,66 +2734,53 @@ r"""
2806
2734
  progress_percent = 100 * min(
2807
2735
  1, max(0, elapsed / train_seconds, it / (epochs - 1)))
2808
2736
  callback = {
2809
- 'status': status,
2810
2737
  'iteration': it,
2738
+ 'elapsed_time': elapsed,
2739
+ 'progress': progress_percent,
2740
+ 'status': status,
2741
+ 'key': key,
2811
2742
  'train_return':-train_loss,
2812
2743
  'test_return':-test_loss_smooth,
2813
2744
  'best_return':-best_loss,
2814
2745
  'pgpe_return': pgpe_return,
2746
+ 'last_iteration_improved': last_iter_improve,
2747
+ 'pgpe_improved': pgpe_improve,
2815
2748
  'params': policy_params,
2816
2749
  'best_params': best_params,
2750
+ 'best_index': best_index,
2817
2751
  'pgpe_params': pgpe_params,
2818
- 'last_iteration_improved': last_iter_improve,
2819
- 'pgpe_improved': pgpe_improve,
2752
+ 'model_params': model_params,
2753
+ 'policy_hyperparams': policy_hyperparams,
2820
2754
  'grad': train_log['grad'],
2821
2755
  'best_grad': best_grad,
2822
- 'updates': train_log['updates'],
2823
- 'elapsed_time': elapsed,
2824
- 'key': key,
2825
- 'model_params': model_params,
2826
- 'progress': progress_percent,
2827
2756
  'train_log': train_log,
2828
- 'policy_hyperparams': policy_hyperparams,
2829
- **test_log
2757
+ 'test_log': test_log
2830
2758
  }
2831
2759
 
2832
- # hard restart
2833
- if guess is None and no_progress_flag:
2834
- no_progress_count += 1
2835
- if no_progress_count > restart_epochs:
2836
- key, subkey = random.split(key)
2837
- policy_params, opt_state, opt_aux = self.initialize(
2838
- subkey, policy_hyperparams, train_subs)
2839
- no_progress_count = 0
2840
- if self.print_warnings and progress_bar is not None:
2841
- message = termcolor.colored(
2842
- f'[INFO] Optimizer restarted at iteration {it} '
2843
- f'due to lack of progress.', 'green')
2844
- progress_bar.write(message)
2845
- else:
2846
- no_progress_count = 0
2847
-
2848
2760
  # stopping condition reached
2849
2761
  if stopping_rule is not None and stopping_rule.monitor(callback):
2850
2762
  if self.print_warnings and progress_bar is not None:
2851
- message = termcolor.colored(
2852
- '[SUCC] Stopping rule has been reached.', 'green')
2853
- progress_bar.write(message)
2763
+ progress_bar.write(termcolor.colored(
2764
+ '[SUCC] Stopping rule has been reached.', 'green'
2765
+ ))
2854
2766
  callback['status'] = status = JaxPlannerStatus.STOPPING_RULE_REACHED
2855
2767
 
2856
2768
  # if the progress bar is used
2857
2769
  if print_progress:
2858
2770
  progress_bar.set_description(
2859
- f'{position_str} {it:6} it / {-np.min(train_loss):14.5f} train / '
2860
- f'{-np.min(test_loss_smooth):14.5f} test / {-best_loss:14.5f} best / '
2861
- f'{status.value} status / {total_pgpe_it:6} pgpe',
2862
- refresh=False)
2771
+ f'{position_str} {it} it | {-np.min(train_loss):13.5f} train | '
2772
+ f'{-np.min(test_loss_smooth):13.5f} test | '
2773
+ f'{-best_loss:13.5f} best | '
2774
+ f'{total_pgpe_it} pgpe | {status.value} status',
2775
+ refresh=False
2776
+ )
2863
2777
  progress_bar.set_postfix_str(
2864
- f'{(it + 1) / (elapsed + 1e-6):.2f}it/s', refresh=False)
2778
+ f'{(it + 1) / (elapsed + 1e-6):.2f}it/s', refresh=False
2779
+ )
2865
2780
  progress_bar.update(progress_percent - progress_bar.n)
2866
2781
 
2867
- # dash-board
2868
- if dashboard is not None:
2782
+ # dashboard
2783
+ if dashboard is not None:
2869
2784
  dashboard.update_experiment(dashboard_id, callback)
2870
2785
 
2871
2786
  # yield the callback
@@ -2884,28 +2799,51 @@ r"""
2884
2799
  # release resources
2885
2800
  if print_progress:
2886
2801
  progress_bar.close()
2887
- print()
2888
2802
 
2889
2803
  # summarize and test for convergence
2890
2804
  if print_summary:
2891
- grad_norm = jax.tree_util.tree_map(
2892
- lambda x: np.linalg.norm(x).item(), best_grad)
2805
+
2806
+ # calculate gradient norm
2807
+ grad_norm = jax.tree_util.tree_map(lambda x: np.linalg.norm(x).item(), best_grad)
2808
+ grad_norms = jax.tree_util.tree_leaves(grad_norm)
2809
+ max_grad_norm = max(grad_norms) if grad_norms else np.nan
2810
+
2811
+ # calculate best policy return
2812
+ _, (final_log, _) = self.test_loss(
2813
+ key, self._broadcast_pytree(best_params), policy_hyperparams,
2814
+ *test_subs, model_params_test
2815
+ )
2816
+ best_returns = np.ravel(np.sum(final_log['reward'], axis=2))
2817
+ mean, rlo, rhi = self.ci_bootstrap(best_returns)
2818
+
2819
+ # diagnosis
2893
2820
  diagnosis = self._perform_diagnosis(
2894
2821
  last_iter_improve, -np.min(train_loss), -np.min(test_loss_smooth),
2895
- -best_loss, grad_norm)
2896
- print(f'Summary of optimization:\n'
2897
- f' status ={status}\n'
2898
- f' time ={elapsed:.3f} sec.\n'
2899
- f' iterations ={it}\n'
2900
- f' best objective={-best_loss:.6f}\n'
2901
- f' best grad norm={grad_norm}\n'
2902
- f'diagnosis: {diagnosis}\n')
2822
+ -best_loss, max_grad_norm
2823
+ )
2824
+ print(
2825
+ f'[INFO] Summary of optimization:\n'
2826
+ f' status: {status}\n'
2827
+ f' time: {elapsed:.2f} seconds\n'
2828
+ f' iterations: {it}\n'
2829
+ f' best objective: {-best_loss:.5f}\n'
2830
+ f' best grad norm: {max_grad_norm:.5f}\n'
2831
+ f' best cuml reward: Mean = {mean:.5f}, 95% CI [{rlo:.5f}, {rhi:.5f}]\n'
2832
+ f' diagnosis: {diagnosis}\n'
2833
+ )
2903
2834
 
2904
- def _perform_diagnosis(self, last_iter_improve,
2905
- train_return, test_return, best_return, grad_norm):
2906
- grad_norms = jax.tree_util.tree_leaves(grad_norm)
2907
- max_grad_norm = max(grad_norms) if grad_norms else np.nan
2908
- grad_is_zero = np.allclose(max_grad_norm, 0)
2835
+ @staticmethod
2836
+ def ci_bootstrap(returns, confidence=0.95, n_boot=10000):
2837
+ means = np.zeros((n_boot,))
2838
+ for i in range(n_boot):
2839
+ means[i] = np.mean(np.random.choice(returns, size=len(returns), replace=True))
2840
+ lower = np.percentile(means, (1 - confidence) / 2 * 100)
2841
+ upper = np.percentile(means, (1 + confidence) / 2 * 100)
2842
+ mean = np.mean(returns)
2843
+ return mean, lower, upper
2844
+
2845
+ def _perform_diagnosis(self, last_iter_improve, train_return, test_return, best_return,
2846
+ max_grad_norm):
2909
2847
 
2910
2848
  # divergence if the solution is not finite
2911
2849
  if not np.isfinite(train_return):
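The optimization summary now reports a 95% bootstrap confidence interval on the cumulative reward of the best parameters, produced by the ci_bootstrap helper added above. The same percentile bootstrap can be written in vectorized NumPy, which is convenient as a sanity check of the reported interval:

    import numpy as np

    def bootstrap_ci(returns, confidence=0.95, n_boot=10000, seed=0):
        # resample the returns with replacement and take percentiles of the
        # resampled means (the same percentile bootstrap as ci_bootstrap)
        rng = np.random.default_rng(seed)
        returns = np.asarray(returns)
        samples = rng.choice(returns, size=(n_boot, returns.size), replace=True)
        means = samples.mean(axis=1)
        lower = np.percentile(means, (1 - confidence) / 2 * 100)
        upper = np.percentile(means, (1 + confidence) / 2 * 100)
        return returns.mean(), lower, upper

    mean, lo, hi = bootstrap_ci([12.0, 9.5, 14.2, 11.1, 10.7])
    print(f'Mean = {mean:.5f}, 95% CI [{lo:.5f}, {hi:.5f}]')

The sample returns here are made up; in the planner they come from summing the per-step rewards of the final evaluation rollouts.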
@@ -2914,64 +2852,61 @@ r"""
2914
2852
  # hit a plateau is likely IF:
2915
2853
  # 1. planner does not improve at all
2916
2854
  # 2. the gradient norm at the best solution is zero
2855
+ grad_is_zero = np.allclose(max_grad_norm, 0)
2917
2856
  if last_iter_improve <= 1:
2918
2857
  if grad_is_zero:
2919
2858
  return termcolor.colored(
2920
- f'[FAIL] No progress was made '
2921
- f'and max grad norm {max_grad_norm:.6f} was zero: '
2922
- f'solver likely stuck in a plateau.', 'red')
2859
+ f'[FAIL] No progress and ||g||={max_grad_norm:.4f}, '
2860
+ f'solver initialized in a plateau.', 'red'
2861
+ )
2923
2862
  else:
2924
2863
  return termcolor.colored(
2925
- f'[FAIL] No progress was made '
2926
- f'but max grad norm {max_grad_norm:.6f} was non-zero: '
2927
- f'learning rate or other hyper-parameters could be suboptimal.',
2928
- 'red')
2864
+ f'[FAIL] No progress and ||g||={max_grad_norm:.4f}, '
2865
+ f'adjust learning rate or other parameters.', 'red'
2866
+ )
2929
2867
 
2930
2868
  # model is likely poor IF:
2931
2869
  # 1. the train and test return disagree
2932
- validation_error = 100 * abs(test_return - train_return) / \
2933
- max(abs(train_return), abs(test_return))
2934
- if not (validation_error < 20):
2870
+ validation_error = (abs(test_return - train_return) /
2871
+ max(abs(train_return), abs(test_return)))
2872
+ if not (validation_error < 0.2):
2935
2873
  return termcolor.colored(
2936
- f'[WARN] Progress was made '
2937
- f'but relative train-test error {validation_error:.6f} was high: '
2938
- f'poor model relaxation around solution or batch size too small.',
2939
- 'yellow')
2874
+ f'[WARN] Progress but large rel. train/test error {validation_error:.4f}, '
2875
+ f'adjust model or batch size.', 'yellow'
2876
+ )
2940
2877
 
2941
2878
  # model likely did not converge IF:
2942
2879
  # 1. the max grad relative to the return is high
2943
2880
  if not grad_is_zero:
2944
- return_to_grad_norm = abs(best_return) / max_grad_norm
2945
- if not (return_to_grad_norm > 1):
2881
+ if not (abs(best_return) > 1.0 * max_grad_norm):
2946
2882
  return termcolor.colored(
2947
- f'[WARN] Progress was made '
2948
- f'but max grad norm {max_grad_norm:.6f} was high: '
2949
- f'solution locally suboptimal, relaxed model nonsmooth around solution, '
2950
- f'or batch size too small.', 'yellow')
2883
+ f'[WARN] Progress but large ||g||={max_grad_norm:.4f}, '
2884
+ f'adjust learning rate or budget.', 'yellow'
2885
+ )
2951
2886
 
2952
2887
  # likely successful
2953
2888
  return termcolor.colored(
2954
- '[SUCC] Planner converged successfully '
2955
- '(note: not all problems can be ruled out).', 'green')
2889
+ '[SUCC] No convergence problems found.', 'green'
2890
+ )
2956
2891
 
2957
2892
  def get_action(self, key: random.PRNGKey,
2958
2893
  params: Pytree,
2959
2894
  step: int,
2960
- subs: Dict[str, Any],
2895
+ state: Dict[str, Any],
2961
2896
  policy_hyperparams: Optional[Dict[str, Any]]=None) -> Dict[str, Any]:
2962
2897
  '''Returns an action dictionary from the policy or plan with the given parameters.
2963
2898
 
2964
2899
  :param key: the JAX PRNG key
2965
2900
  :param params: the trainable parameter PyTree of the policy
2966
2901
  :param step: the time step at which decision is made
2967
- :param subs: the dict of pvariables
2902
+ :param state: the dict of state p-variables
2968
2903
  :param policy_hyperparams: hyper-parameters for the policy/plan, such as
2969
2904
  weights for sigmoid wrapping boolean actions (optional)
2970
2905
  '''
2971
- subs = subs.copy()
2906
+ state = state.copy()
2972
2907
 
2973
- # check compatibility of the subs dictionary
2974
- for (var, values) in subs.items():
2908
+ # check compatibility of the state dictionary
2909
+ for (var, values) in state.items():
2975
2910
 
2976
2911
  # must not be grounded
2977
2912
  if RDDLPlanningModel.FLUENT_SEP in var or RDDLPlanningModel.OBJECT_SEP in var:
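The diagnosis above now measures the train/test discrepancy as a fraction rather than a percentage, with the warning threshold changed from 20 to 0.2 to match. A quick worked example with illustrative returns:

    train_return, test_return = 100.0, 85.0   # made-up values for illustration
    validation_error = (abs(test_return - train_return) /
                        max(abs(train_return), abs(test_return)))
    # validation_error == 0.15, below the 0.2 threshold, so no [WARN] is issued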
@@ -2985,18 +2920,19 @@ r"""
2985
2920
  dtype = np.result_type(values)
2986
2921
  if not np.issubdtype(dtype, np.number) and not np.issubdtype(dtype, np.bool_):
2987
2922
  if step == 0 and var in self.rddl.observ_fluents:
2988
- subs[var] = self.test_compiled.init_values[var]
2923
+ state[var] = self.test_compiled.init_values[var]
2989
2924
  else:
2990
2925
  if dtype.type is np.str_:
2991
2926
  prange = self.rddl.variable_ranges[var]
2992
- subs[var] = self.rddl.object_string_to_index_array(prange, subs[var])
2927
+ state[var] = self.rddl.object_string_to_index_array(prange, state[var])
2993
2928
  else:
2994
2929
  raise ValueError(
2995
2930
  f'Values {values} assigned to p-variable <{var}> are '
2996
- f'non-numeric of type {dtype}.')
2931
+ f'non-numeric of type {dtype}.'
2932
+ )
2997
2933
 
2998
2934
  # cast device arrays to numpy
2999
- actions = self.test_policy(key, params, policy_hyperparams, step, subs)
2935
+ actions = self.test_policy(key, params, policy_hyperparams, step, state)
3000
2936
  actions = jax.tree_util.tree_map(np.asarray, actions)
3001
2937
  return actions
3002
2938
 
@@ -3053,7 +2989,8 @@ class JaxOfflineController(BaseAgent):
3053
2989
  with open(params, 'rb') as file:
3054
2990
  params = pickle.load(file)
3055
2991
 
3056
- # train the policy
2992
+ # train the policy once before starting to step() through the environment
2993
+ # and then execute this policy in open-loop fashion
3057
2994
  self.step = 0
3058
2995
  self.callback = None
3059
2996
  if not self.train_on_reset and not self.params_given:
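The expanded comment spells out the offline controller's contract: the whole optimization happens once up front, and sample_action() afterwards replays the trained plan open-loop. A typical usage sketch is below; the domain and instance names are illustrative (any RDDL environment created with vectorized=True should behave the same), and the extra keyword arguments are assumed to be forwarded to the planner's optimize() call:

    import pyRDDLGym
    from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOfflineController

    # hypothetical domain/instance from rddlrepository
    env = pyRDDLGym.make('Wildfire_MDP_ippc2014', '1', vectorized=True)

    planner = JaxBackpropPlanner(rddl=env.model)
    controller = JaxOfflineController(planner, train_seconds=30)   # trains here, once
    controller.evaluate(env, episodes=1, verbose=True)             # open-loop execution
    env.close()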
@@ -3126,22 +3063,33 @@ class JaxOnlineController(BaseAgent):
3126
3063
  self.reset()
3127
3064
 
3128
3065
  def sample_action(self, state: Dict[str, Any]) -> Dict[str, Any]:
3066
+
3067
+ # we train the policy from the current state every time we step()
3129
3068
  planner = self.planner
3130
3069
  callback = planner.optimize(
3131
3070
  key=self.key, guess=self.guess, subs=state, **self.train_kwargs)
3132
3071
 
3133
3072
  # optimize again if jit compilation takes up the entire time budget
3073
+ # this can be done for several attempts until the optimizer has traced the
3074
+ # computation graph: we report the callback of the successful attempt (if exists)
3134
3075
  attempts = 0
3135
3076
  while attempts < self.max_attempts and callback['iteration'] <= 1:
3136
3077
  attempts += 1
3137
3078
  if self.planner.print_warnings:
3138
- message = termcolor.colored(
3139
- f'[WARN] JIT compilation dominated the execution time: '
3079
+ print(termcolor.colored(
3080
+ f'[INFO] JIT compilation dominated the execution time: '
3140
3081
  f'executing the optimizer again on the traced model '
3141
- f'[attempt {attempts}].', 'yellow')
3142
- print(message)
3082
+ f'[attempt {attempts}].', 'dark_grey'
3083
+ ))
3143
3084
  callback = planner.optimize(
3144
- key=self.key, guess=self.guess, subs=state, **self.train_kwargs)
3085
+ key=self.key, guess=self.guess, subs=state, **self.train_kwargs)
3086
+ if callback['iteration'] <= 1 and self.planner.print_warnings:
3087
+ print(termcolor.colored(
3088
+ f'[FAIL] JIT compilation dominated the execution time and '
3089
+ f'ran out of attempts: increase max_attempts or the training time.', 'red'
3090
+ ))
3091
+
3092
+ # use the last callback obtained
3145
3093
  self.callback = callback
3146
3094
  params = callback['best_params']
3147
3095
  if not self.hyperparams_given: