pyRDDLGym-jax: 0.1-py3-none-any.whl → 0.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. pyRDDLGym_jax/core/compiler.py +445 -221
  2. pyRDDLGym_jax/core/logic.py +129 -62
  3. pyRDDLGym_jax/core/planner.py +699 -332
  4. pyRDDLGym_jax/core/simulator.py +5 -7
  5. pyRDDLGym_jax/core/tuning.py +23 -12
  6. pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_drp.cfg → Cartpole_Continuous_gym_drp.cfg} +2 -3
  7. pyRDDLGym_jax/examples/configs/{HVAC_drp.cfg → HVAC_ippc2023_drp.cfg} +2 -2
  8. pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +19 -0
  9. pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +18 -0
  10. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +18 -0
  11. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +1 -1
  12. pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +1 -1
  13. pyRDDLGym_jax/examples/configs/default_drp.cfg +19 -0
  14. pyRDDLGym_jax/examples/configs/default_replan.cfg +20 -0
  15. pyRDDLGym_jax/examples/configs/default_slp.cfg +19 -0
  16. pyRDDLGym_jax/examples/run_gradient.py +1 -1
  17. pyRDDLGym_jax/examples/run_gym.py +1 -2
  18. pyRDDLGym_jax/examples/run_plan.py +7 -0
  19. pyRDDLGym_jax/examples/run_tune.py +6 -0
  20. {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.2.dist-info}/METADATA +1 -1
  21. pyRDDLGym_jax-0.2.dist-info/RECORD +46 -0
  22. {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.2.dist-info}/WHEEL +1 -1
  23. pyRDDLGym_jax-0.1.dist-info/RECORD +0 -40
  24. /pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_replan.cfg → Cartpole_Continuous_gym_replan.cfg} +0 -0
  25. /pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_slp.cfg → Cartpole_Continuous_gym_slp.cfg} +0 -0
  26. /pyRDDLGym_jax/examples/configs/{HVAC_slp.cfg → HVAC_ippc2023_slp.cfg} +0 -0
  27. /pyRDDLGym_jax/examples/configs/{MarsRover_drp.cfg → MarsRover_ippc2023_drp.cfg} +0 -0
  28. /pyRDDLGym_jax/examples/configs/{MarsRover_slp.cfg → MarsRover_ippc2023_slp.cfg} +0 -0
  29. /pyRDDLGym_jax/examples/configs/{MountainCar_slp.cfg → MountainCar_Continuous_gym_slp.cfg} +0 -0
  30. /pyRDDLGym_jax/examples/configs/{Pendulum_slp.cfg → Pendulum_gym_slp.cfg} +0 -0
  31. /pyRDDLGym_jax/examples/configs/{PowerGen_drp.cfg → PowerGen_Continuous_drp.cfg} +0 -0
  32. /pyRDDLGym_jax/examples/configs/{PowerGen_replan.cfg → PowerGen_Continuous_replan.cfg} +0 -0
  33. /pyRDDLGym_jax/examples/configs/{PowerGen_slp.cfg → PowerGen_Continuous_slp.cfg} +0 -0
  34. {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.2.dist-info}/LICENSE +0 -0
  35. {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.2.dist-info}/top_level.txt +0 -0
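The hunks that follow are from pyRDDLGym_jax/core/planner.py, by far the largest change (+699 -332). For orientation only (not part of the diff itself): in 0.2, load_config returns keyword dictionaries for the planner, the plan, and the optimize() call, and optimize() now returns the final training callback (a dict with a 'status' entry) rather than bare parameters. A minimal sketch of that flow, with a hypothetical domain, instance, and config path:

import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, load_config

# hypothetical inputs: substitute a real domain/instance and a .cfg from the examples
env = pyRDDLGym.make('CartPole_Continuous_gym', '0', vectorized=True)  # vectorized=True is required
planner_args, _, train_args = load_config('Cartpole_Continuous_gym_drp.cfg')

# planner_args carries the 'plan', 'logic' and 'optimizer' objects built by load_config (see below)
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)

# 0.2: optimize() returns the last callback dict instead of the raw parameters
callback = planner.optimize(**train_args)
print(callback['status'])            # a JaxPlannerStatus value
best_params = callback['best_params']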
@@ -1,6 +1,9 @@
+ __version__ = '0.2'
+
  from ast import literal_eval
  from collections import deque
  import configparser
+ from enum import Enum
  import haiku as hk
  import jax
  import jax.numpy as jnp
@@ -13,11 +16,28 @@ import sys
  import termcolor
  import time
  from tqdm import tqdm
- from typing import Callable, Dict, Generator, Set, Sequence, Tuple
+ from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Tuple, Union
+
+ Activation = Callable[[jnp.ndarray], jnp.ndarray]
+ Bounds = Dict[str, Tuple[np.ndarray, np.ndarray]]
+ Kwargs = Dict[str, Any]
+ Pytree = Any
+
+ from pyRDDLGym.core.debug.exception import raise_warning
 
+ # try to import matplotlib, if failed then skip plotting
+ try:
+ import matplotlib
+ import matplotlib.pyplot as plt
+ matplotlib.use('TkAgg')
+ except Exception:
+ raise_warning('matplotlib is not installed, '
+ 'plotting functionality is disabled.', 'red')
+ plt = None
+
  from pyRDDLGym.core.compiler.model import RDDLPlanningModel, RDDLLiftedModel
+ from pyRDDLGym.core.debug.logger import Logger
  from pyRDDLGym.core.debug.exception import (
- raise_warning,
  RDDLNotImplementedError,
  RDDLUndefinedVariableError,
  RDDLTypeError
@@ -37,6 +57,7 @@ from pyRDDLGym_jax.core.logic import FuzzyLogic
  # - instantiate planner
  #
  # ***********************************************************************
+
  def _parse_config_file(path: str):
  if not os.path.isfile(path):
  raise FileNotFoundError(f'File {path} does not exist.')
@@ -59,51 +80,94 @@ def _parse_config_string(value: str):
  return config, args
 
 
+ def _getattr_any(packages, item):
+ for package in packages:
+ loaded = getattr(package, item, None)
+ if loaded is not None:
+ return loaded
+ return None
+
+
  def _load_config(config, args):
  model_args = {k: args[k] for (k, _) in config.items('Model')}
  planner_args = {k: args[k] for (k, _) in config.items('Optimizer')}
  train_args = {k: args[k] for (k, _) in config.items('Training')}
 
- train_args['key'] = jax.random.PRNGKey(train_args['key'])
-
  # read the model settings
- tnorm_name = model_args['tnorm']
- tnorm_kwargs = model_args['tnorm_kwargs']
- logic_name = model_args['logic']
- logic_kwargs = model_args['logic_kwargs']
+ logic_name = model_args.get('logic', 'FuzzyLogic')
+ logic_kwargs = model_args.get('logic_kwargs', {})
+ tnorm_name = model_args.get('tnorm', 'ProductTNorm')
+ tnorm_kwargs = model_args.get('tnorm_kwargs', {})
+ comp_name = model_args.get('complement', 'StandardComplement')
+ comp_kwargs = model_args.get('complement_kwargs', {})
+ compare_name = model_args.get('comparison', 'SigmoidComparison')
+ compare_kwargs = model_args.get('comparison_kwargs', {})
  logic_kwargs['tnorm'] = getattr(logic, tnorm_name)(**tnorm_kwargs)
- planner_args['logic'] = getattr(logic, logic_name)(**logic_kwargs)
+ logic_kwargs['complement'] = getattr(logic, comp_name)(**comp_kwargs)
+ logic_kwargs['comparison'] = getattr(logic, compare_name)(**compare_kwargs)
 
- # read the optimizer settings
+ # read the policy settings
  plan_method = planner_args.pop('method')
  plan_kwargs = planner_args.pop('method_kwargs', {})
 
- if 'initializer' in plan_kwargs: # weight initialization
- init_name = plan_kwargs['initializer']
- init_class = getattr(initializers, init_name)
- init_kwargs = plan_kwargs.pop('initializer_kwargs', {})
- try:
- plan_kwargs['initializer'] = init_class(**init_kwargs)
- except:
- raise_warning(f'ignoring arguments for initializer <{init_name}>')
- plan_kwargs['initializer'] = init_class
-
- if 'activation' in plan_kwargs: # activation function
- plan_kwargs['activation'] = getattr(jax.nn, plan_kwargs['activation'])
+ # policy initialization
+ plan_initializer = plan_kwargs.get('initializer', None)
+ if plan_initializer is not None:
+ initializer = _getattr_any(packages=[initializers], item=plan_initializer)
+ if initializer is None:
+ raise_warning(
+ f'Ignoring invalid initializer <{plan_initializer}>.', 'red')
+ del plan_kwargs['initializer']
+ else:
+ init_kwargs = plan_kwargs.pop('initializer_kwargs', {})
+ try:
+ plan_kwargs['initializer'] = initializer(**init_kwargs)
+ except Exception as _:
+ raise_warning(
+ f'Ignoring invalid initializer_kwargs <{init_kwargs}>.', 'red')
+ plan_kwargs['initializer'] = initializer
 
+ # policy activation
+ plan_activation = plan_kwargs.get('activation', None)
+ if plan_activation is not None:
+ activation = _getattr_any(packages=[jax.nn, jax.numpy], item=plan_activation)
+ if activation is None:
+ raise_warning(
+ f'Ignoring invalid activation <{plan_activation}>.', 'red')
+ del plan_kwargs['activation']
+ else:
+ plan_kwargs['activation'] = activation
+
+ # read the planner settings
+ planner_args['logic'] = getattr(logic, logic_name)(**logic_kwargs)
  planner_args['plan'] = getattr(sys.modules[__name__], plan_method)(**plan_kwargs)
- planner_args['optimizer'] = getattr(optax, planner_args['optimizer'])
+
+ # planner optimizer
+ planner_optimizer = planner_args.get('optimizer', None)
+ if planner_optimizer is not None:
+ optimizer = _getattr_any(packages=[optax], item=planner_optimizer)
+ if optimizer is None:
+ raise_warning(
+ f'Ignoring invalid optimizer <{planner_optimizer}>.', 'red')
+ del planner_args['optimizer']
+ else:
+ planner_args['optimizer'] = optimizer
+
+ # read the optimize call settings
+ planner_key = train_args.get('key', None)
+ if planner_key is not None:
+ train_args['key'] = random.PRNGKey(planner_key)
 
  return planner_args, plan_kwargs, train_args
 
 
- def load_config(path: str) -> Tuple[Dict[str, object], ...]:
+ def load_config(path: str) -> Tuple[Kwargs, ...]:
  '''Loads a config file at the specified file path.'''
  config, args = _parse_config_file(path)
  return _load_config(config, args)
 
 
- def load_config_from_string(value: str) -> Tuple[Dict[str, object], ...]:
+ def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
  '''Loads config file contents specified explicitly as a string value.'''
  config, args = _parse_config_string(value)
  return _load_config(config, args)
@@ -115,6 +179,20 @@ def load_config_from_string(value: str) -> Tuple[Dict[str, object], ...]:
  # - replace discrete ops in state dynamics/reward with differentiable ones
  #
  # ***********************************************************************
+
+ def _function_discrete_approx_named(logic):
+ jax_discrete, jax_param = logic.discrete()
+
+ def _jax_wrapped_discrete_calc_approx(key, prob, params):
+ sample = jax_discrete(key, prob, params)
+ out_of_bounds = jnp.logical_not(jnp.logical_and(
+ jnp.all(prob >= 0),
+ jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
+ return sample, out_of_bounds
+
+ return _jax_wrapped_discrete_calc_approx, jax_param
+
+
  class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
  '''Compiles a RDDL AST representation to an equivalent JAX representation.
  Unlike its parent class, this class treats all fluents as real-valued, and
@@ -124,7 +202,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
 
  def __init__(self, *args,
  logic: FuzzyLogic=FuzzyLogic(),
- cpfs_without_grad: Set=set(),
+ cpfs_without_grad: Optional[Set[str]]=None,
  **kwargs) -> None:
  '''Creates a new RDDL to Jax compiler, where operations that are not
  differentiable are converted to approximate forms that have defined
@@ -140,27 +218,30 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
  '''
  super(JaxRDDLCompilerWithGrad, self).__init__(*args, **kwargs)
  self.logic = logic
+ self.logic.set_use64bit(self.use64bit)
+ if cpfs_without_grad is None:
+ cpfs_without_grad = set()
  self.cpfs_without_grad = cpfs_without_grad
 
  # actions and CPFs must be continuous
- raise_warning(f'Initial values of pvariables will be cast to real.')
+ raise_warning('Initial values of pvariables will be cast to real.')
  for (var, values) in self.init_values.items():
  self.init_values[var] = np.asarray(values, dtype=self.REAL)
 
  # overwrite basic operations with fuzzy ones
  self.RELATIONAL_OPS = {
- '>=': logic.greaterEqual(),
- '<=': logic.lessEqual(),
+ '>=': logic.greater_equal(),
+ '<=': logic.less_equal(),
  '<': logic.less(),
  '>': logic.greater(),
  '==': logic.equal(),
- '~=': logic.notEqual()
+ '~=': logic.not_equal()
  }
- self.LOGICAL_NOT = logic.Not()
+ self.LOGICAL_NOT = logic.logical_not()
  self.LOGICAL_OPS = {
- '^': logic.And(),
- '&': logic.And(),
- '|': logic.Or(),
+ '^': logic.logical_and(),
+ '&': logic.logical_and(),
+ '|': logic.logical_or(),
  '~': logic.xor(),
  '=>': logic.implies(),
  '<=>': logic.equiv()
@@ -169,15 +250,19 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
  self.AGGREGATION_OPS['exists'] = logic.exists()
  self.AGGREGATION_OPS['argmin'] = logic.argmin()
  self.AGGREGATION_OPS['argmax'] = logic.argmax()
- self.KNOWN_UNARY['sgn'] = logic.signum()
+ self.KNOWN_UNARY['sgn'] = logic.sgn()
  self.KNOWN_UNARY['floor'] = logic.floor()
  self.KNOWN_UNARY['ceil'] = logic.ceil()
  self.KNOWN_UNARY['round'] = logic.round()
  self.KNOWN_UNARY['sqrt'] = logic.sqrt()
- self.KNOWN_BINARY['div'] = logic.floorDiv()
+ self.KNOWN_BINARY['div'] = logic.div()
  self.KNOWN_BINARY['mod'] = logic.mod()
  self.KNOWN_BINARY['fmod'] = logic.mod()
-
+ self.IF_HELPER = logic.control_if()
+ self.SWITCH_HELPER = logic.control_switch()
+ self.BERNOULLI_HELPER = logic.bernoulli()
+ self.DISCRETE_HELPER = _function_discrete_approx_named(logic)
+
  def _jax_stop_grad(self, jax_expr):
 
  def _jax_wrapped_stop_grad(x, params, key):
@@ -199,35 +284,13 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
  jax_cpfs[cpf] = self._jax_stop_grad(jax_cpfs[cpf])
  return jax_cpfs
 
- def _jax_if_helper(self):
- return self.logic.If()
-
- def _jax_switch_helper(self):
- return self.logic.Switch()
-
  def _jax_kron(self, expr, info):
  if self.logic.verbose:
  raise_warning('KronDelta will be ignored.')
-
  arg, = expr.args
  arg = self._jax(arg, info)
  return arg
 
- def _jax_bernoulli_helper(self):
- return self.logic.bernoulli()
-
- def _jax_discrete_helper(self):
- jax_discrete, jax_param = self.logic.discrete()
-
- def _jax_wrapped_discrete_calc_approx(key, prob, params):
- sample = jax_discrete(key, prob, params)
- out_of_bounds = jnp.logical_not(jnp.logical_and(
- jnp.all(prob >= 0),
- jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
- return sample, out_of_bounds
-
- return _jax_wrapped_discrete_calc_approx, jax_param
-
 
  # ***********************************************************************
  # ALL VERSIONS OF JAX PLANS
@@ -236,6 +299,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
  # - deep reactive policy
  #
  # ***********************************************************************
+
  class JaxPlan:
  '''Base class for all JAX policy representations.'''
 
@@ -245,15 +309,15 @@ class JaxPlan:
  self._test_policy = None
  self._projection = None
 
- def summarize_hyperparameters(self):
+ def summarize_hyperparameters(self) -> None:
  pass
 
  def compile(self, compiled: JaxRDDLCompilerWithGrad,
- _bounds: Dict,
+ _bounds: Bounds,
  horizon: int) -> None:
  raise NotImplementedError
 
- def guess_next_epoch(self, params: Dict) -> Dict:
+ def guess_next_epoch(self, params: Pytree) -> Pytree:
  raise NotImplementedError
 
  @property
@@ -289,7 +353,8 @@ class JaxPlan:
  self._projection = value
 
  def _calculate_action_info(self, compiled: JaxRDDLCompilerWithGrad,
- user_bounds: Dict[str, object], horizon: int):
+ user_bounds: Bounds,
+ horizon: int):
  shapes, bounds, bounds_safe, cond_lists = {}, {}, {}, {}
  for (name, prange) in compiled.rddl.variable_ranges.items():
  if compiled.rddl.variable_types[name] != 'action-fluent':
@@ -309,8 +374,8 @@ class JaxPlan:
  else:
  lower, upper = compiled.constraints.bounds[name]
  lower, upper = user_bounds.get(name, (lower, upper))
- lower = np.asarray(lower, dtype=np.float32)
- upper = np.asarray(upper, dtype=np.float32)
+ lower = np.asarray(lower, dtype=compiled.REAL)
+ upper = np.asarray(upper, dtype=compiled.REAL)
  lower_finite = np.isfinite(lower)
  upper_finite = np.isfinite(upper)
  bounds_safe[name] = (np.where(lower_finite, lower, 0.0),
@@ -336,7 +401,7 @@ class JaxStraightLinePlan(JaxPlan):
 
  def __init__(self, initializer: initializers.Initializer=initializers.normal(),
  wrap_sigmoid: bool=True,
- min_action_prob: float=1e-5,
+ min_action_prob: float=1e-6,
  wrap_non_bool: bool=False,
  wrap_softmax: bool=False,
  use_new_projection: bool=False,
@@ -371,7 +436,7 @@ class JaxStraightLinePlan(JaxPlan):
  self._use_new_projection = use_new_projection
  self._max_constraint_iter = max_constraint_iter
 
- def summarize_hyperparameters(self):
+ def summarize_hyperparameters(self) -> None:
  print(f'policy hyper-parameters:\n'
  f' initializer ={type(self._initializer_base).__name__}\n'
  f'constraint-sat strategy (simple):\n'
@@ -383,7 +448,8 @@ class JaxStraightLinePlan(JaxPlan):
  f' use_new_projection ={self._use_new_projection}')
 
  def compile(self, compiled: JaxRDDLCompilerWithGrad,
- _bounds: Dict, horizon: int) -> None:
+ _bounds: Bounds,
+ horizon: int) -> None:
  rddl = compiled.rddl
 
  # calculate the correct action box bounds
@@ -423,7 +489,7 @@ class JaxStraightLinePlan(JaxPlan):
  def _jax_bool_action_to_param(var, action, hyperparams):
  if wrap_sigmoid:
  weight = hyperparams[var]
- return (-1.0 / weight) * jnp.log1p(1.0 / action - 2.0)
+ return (-1.0 / weight) * jnp.log(1.0 / action - 1.0)
  else:
  return action
 
@@ -506,7 +572,7 @@ class JaxStraightLinePlan(JaxPlan):
  def _jax_wrapped_slp_predict_test(key, params, hyperparams, step, subs):
  actions = {}
  for (var, param) in params.items():
- action = jnp.asarray(param[step, ...])
+ action = jnp.asarray(param[step, ...], dtype=compiled.REAL)
  if var == bool_key:
  output = jax.nn.softmax(action)
  bool_actions = _jax_unstack_bool_from_softmax(output)
@@ -688,7 +754,7 @@ class JaxStraightLinePlan(JaxPlan):
  # "progress" the plan one step forward and set last action to second-last
  return jnp.append(param[1:, ...], param[-1:, ...], axis=0)
 
- def guess_next_epoch(self, params: Dict) -> Dict:
+ def guess_next_epoch(self, params: Pytree) -> Pytree:
  next_fn = JaxStraightLinePlan._guess_next_epoch
  return jax.tree_map(next_fn, params)
 
@@ -696,10 +762,12 @@ class JaxStraightLinePlan(JaxPlan):
  class JaxDeepReactivePolicy(JaxPlan):
  '''A deep reactive policy network implementation in JAX.'''
 
- def __init__(self, topology: Sequence[int],
- activation: Callable=jax.nn.relu,
+ def __init__(self, topology: Optional[Sequence[int]]=None,
+ activation: Activation=jnp.tanh,
  initializer: hk.initializers.Initializer=hk.initializers.VarianceScaling(scale=2.0),
- normalize: bool=True) -> None:
+ normalize: bool=True,
+ normalizer_kwargs: Optional[Kwargs]=None,
+ wrap_non_bool: bool=False) -> None:
  '''Creates a new deep reactive policy in JAX.
 
  :param neurons: sequence consisting of the number of neurons in each
@@ -707,23 +775,39 @@ class JaxDeepReactivePolicy(JaxPlan):
  :param activation: function to apply after each layer of the policy
  :param initializer: weight initialization
  :param normalize: whether to apply layer norm to the inputs
+ :param normalizer_kwargs: if normalize is True, apply additional arguments
+ to layer norm
+ :param wrap_non_bool: whether to wrap real or int action fluent parameters
+ with non-linearity (e.g. sigmoid or ELU) to satisfy box constraints
  '''
  super(JaxDeepReactivePolicy, self).__init__()
+ if topology is None:
+ topology = [128, 64]
  self._topology = topology
  self._activations = [activation for _ in topology]
  self._initializer_base = initializer
  self._initializer = initializer
  self._normalize = normalize
+ if normalizer_kwargs is None:
+ normalizer_kwargs = {
+ 'create_offset': True, 'create_scale': True,
+ 'name': 'input_norm'
+ }
+ self._normalizer_kwargs = normalizer_kwargs
+ self._wrap_non_bool = wrap_non_bool
 
- def summarize_hyperparameters(self):
+ def summarize_hyperparameters(self) -> None:
  print(f'policy hyper-parameters:\n'
  f' topology ={self._topology}\n'
  f' activation_fn ={self._activations[0].__name__}\n'
  f' initializer ={type(self._initializer_base).__name__}\n'
- f' apply_layer_norm={self._normalize}')
+ f' apply_layer_norm={self._normalize}\n'
+ f' layer_norm_args ={self._normalizer_kwargs}\n'
+ f' wrap_non_bool ={self._wrap_non_bool}')
 
  def compile(self, compiled: JaxRDDLCompilerWithGrad,
- _bounds: Dict, horizon: int) -> None:
+ _bounds: Bounds,
+ horizon: int) -> None:
  rddl = compiled.rddl
 
  # calculate the correct action box bounds
@@ -751,6 +835,7 @@ class JaxDeepReactivePolicy(JaxPlan):
 
  ranges = rddl.variable_ranges
  normalize = self._normalize
+ wrap_non_bool = self._wrap_non_bool
  init = self._initializer
  layers = list(enumerate(zip(self._topology, self._activations)))
  layer_sizes = {var: np.prod(shape, dtype=int)
@@ -763,9 +848,7 @@ class JaxDeepReactivePolicy(JaxPlan):
  # apply layer norm
  if normalize:
  normalizer = hk.LayerNorm(
- axis=-1, param_axis=-1,
- create_offset=True, create_scale=True,
- name='input_norm')
+ axis=-1, param_axis=-1, **self._normalizer_kwargs)
  state = normalizer(state)
 
  # feed state vector through hidden layers
@@ -789,16 +872,19 @@ class JaxDeepReactivePolicy(JaxPlan):
  if not use_constraint_satisfaction:
  actions[var] = jax.nn.sigmoid(output)
  else:
- lower, upper = bounds_safe[var]
- action = jnp.select(
- condlist=cond_lists[var],
- choicelist=[
- lower + (upper - lower) * jax.nn.sigmoid(output),
- lower + (jax.nn.elu(output) + 1.0),
- upper - (jax.nn.elu(-output) + 1.0),
- output
- ]
- )
+ if wrap_non_bool:
+ lower, upper = bounds_safe[var]
+ action = jnp.select(
+ condlist=cond_lists[var],
+ choicelist=[
+ lower + (upper - lower) * jax.nn.sigmoid(output),
+ lower + (jax.nn.elu(output) + 1.0),
+ upper - (jax.nn.elu(-output) + 1.0),
+ output
+ ]
+ )
+ else:
+ action = output
  actions[var] = action
 
  # for constraint satisfaction wrap bool actions with softmax
@@ -826,12 +912,17 @@ class JaxDeepReactivePolicy(JaxPlan):
  actions[name] = action
  start += size
  return actions
-
+
+ if rddl.observ_fluents:
+ observed_vars = rddl.observ_fluents
+ else:
+ observed_vars = rddl.state_fluents
+
  # state is concatenated into single tensor
  def _jax_wrapped_subs_to_state(subs):
  subs = {var: value
  for (var, value) in subs.items()
- if var in rddl.state_fluents}
+ if var in observed_vars}
  flat_subs = jax.tree_map(jnp.ravel, subs)
  states = list(flat_subs.values())
  state = jnp.concatenate(states)
@@ -841,6 +932,10 @@ class JaxDeepReactivePolicy(JaxPlan):
  def _jax_wrapped_drp_predict_train(key, params, hyperparams, step, subs):
  state = _jax_wrapped_subs_to_state(subs)
  actions = predict_fn.apply(params, state)
+ if not wrap_non_bool:
+ for (var, action) in actions.items():
+ if var != bool_key and ranges[var] != 'bool':
+ actions[var] = jnp.clip(action, *bounds[var])
  if use_constraint_satisfaction:
  bool_actions = _jax_unstack_bool_from_softmax(actions[bool_key])
  actions.update(bool_actions)
@@ -886,14 +981,14 @@ class JaxDeepReactivePolicy(JaxPlan):
  def _jax_wrapped_drp_init(key, hyperparams, subs):
  subs = {var: value[0, ...]
  for (var, value) in subs.items()
- if var in rddl.state_fluents}
+ if var in observed_vars}
  state = _jax_wrapped_subs_to_state(subs)
  params = predict_fn.init(key, state)
  return params
 
  self.initializer = _jax_wrapped_drp_init
 
- def guess_next_epoch(self, params: Dict) -> Dict:
+ def guess_next_epoch(self, params: Pytree) -> Pytree:
  return params
 
 
@@ -904,24 +999,135 @@ class JaxDeepReactivePolicy(JaxPlan):
  # - more stable but slower line search based planner
  #
  # ***********************************************************************
+
+ class RollingMean:
+ '''Maintains an estimate of the rolling mean of a stream of real-valued
+ observations.'''
+
+ def __init__(self, window_size: int) -> None:
+ self._window_size = window_size
+ self._memory = deque(maxlen=window_size)
+ self._total = 0
+
+ def update(self, x: float) -> float:
+ memory = self._memory
+ self._total += x
+ if len(memory) == self._window_size:
+ self._total -= memory.popleft()
+ memory.append(x)
+ return self._total / len(memory)
+
+
+ class JaxPlannerPlot:
+ '''Supports plotting and visualization of a JAX policy in real time.'''
+
+ def __init__(self, rddl: RDDLPlanningModel, horizon: int) -> None:
+ self._fig, axes = plt.subplots(1 + len(rddl.action_fluents))
+
+ # prepare the loss plot
+ self._loss_ax = axes[0]
+ self._loss_ax.autoscale(enable=True)
+ self._loss_ax.set_xlabel('decision epoch')
+ self._loss_ax.set_ylabel('loss value')
+ self._loss_plot = self._loss_ax.plot(
+ [], [], linestyle=':', marker='o', markersize=2)[0]
+ self._loss_back = self._fig.canvas.copy_from_bbox(self._loss_ax.bbox)
+
+ # prepare the action plots
+ self._action_ax = {name: axes[idx + 1]
+ for (idx, name) in enumerate(rddl.action_fluents)}
+ self._action_plots = {}
+ for name in rddl.action_fluents:
+ ax = self._action_ax[name]
+ if rddl.variable_ranges[name] == 'bool':
+ vmin, vmax = 0.0, 1.0
+ else:
+ vmin, vmax = None, None
+ action_dim = 1
+ for dim in rddl.object_counts(rddl.variable_params[name]):
+ action_dim *= dim
+ action_plot = ax.pcolormesh(
+ np.zeros((action_dim, horizon)),
+ cmap='seismic', vmin=vmin, vmax=vmax)
+ ax.set_aspect('auto')
+ ax.set_xlabel('decision epoch')
+ ax.set_ylabel(name)
+ plt.colorbar(action_plot, ax=ax)
+ self._action_plots[name] = action_plot
+ self._action_back = {name: self._fig.canvas.copy_from_bbox(ax.bbox)
+ for (name, ax) in self._action_ax.items()}
+
+ plt.tight_layout()
+ plt.show(block=False)
+
+ def redraw(self, xticks, losses, actions) -> None:
+
+ # draw the loss curve
+ self._fig.canvas.restore_region(self._loss_back)
+ self._loss_plot.set_xdata(xticks)
+ self._loss_plot.set_ydata(losses)
+ self._loss_ax.set_xlim([0, len(xticks)])
+ self._loss_ax.set_ylim([np.min(losses), np.max(losses)])
+ self._loss_ax.draw_artist(self._loss_plot)
+ self._fig.canvas.blit(self._loss_ax.bbox)
+
+ # draw the actions
+ for (name, values) in actions.items():
+ values = np.mean(values, axis=0, dtype=float)
+ values = np.reshape(values, newshape=(values.shape[0], -1)).T
+ self._fig.canvas.restore_region(self._action_back[name])
+ self._action_plots[name].set_array(values)
+ self._action_ax[name].draw_artist(self._action_plots[name])
+ self._fig.canvas.blit(self._action_ax[name].bbox)
+ self._action_plots[name].set_clim([np.min(values), np.max(values)])
+ self._fig.canvas.draw()
+ self._fig.canvas.flush_events()
+
+ def close(self) -> None:
+ plt.close(self._fig)
+ del self._loss_ax, self._action_ax, \
+ self._loss_plot, self._action_plots, self._fig, \
+ self._loss_back, self._action_back
+
+
+ class JaxPlannerStatus(Enum):
+ '''Represents the status of a policy update from the JAX planner,
+ including whether the update resulted in nan gradient,
+ whether progress was made, budget was reached, or other information that
+ can be used to monitor and act based on the planner's progress.'''
+
+ NORMAL = 0
+ NO_PROGRESS = 1
+ PRECONDITION_POSSIBLY_UNSATISFIED = 2
+ TIME_BUDGET_REACHED = 3
+ ITER_BUDGET_REACHED = 4
+ INVALID_GRADIENT = 5
+
+ def is_failure(self) -> bool:
+ return self.value >= 3
+
+
  class JaxBackpropPlanner:
  '''A class for optimizing an action sequence in the given RDDL MDP using
  gradient descent.'''
 
  def __init__(self, rddl: RDDLLiftedModel,
  plan: JaxPlan,
- batch_size_train: int,
- batch_size_test: int=None,
- rollout_horizon: int=None,
+ batch_size_train: int=32,
+ batch_size_test: Optional[int]=None,
+ rollout_horizon: Optional[int]=None,
  use64bit: bool=False,
- action_bounds: Dict[str, Tuple[np.ndarray, np.ndarray]]={},
+ action_bounds: Optional[Bounds]=None,
  optimizer: Callable[..., optax.GradientTransformation]=optax.rmsprop,
- optimizer_kwargs: Dict[str, object]={'learning_rate': 0.1},
- clip_grad: float=None,
+ optimizer_kwargs: Optional[Kwargs]=None,
+ clip_grad: Optional[float]=None,
  logic: FuzzyLogic=FuzzyLogic(),
  use_symlog_reward: bool=False,
- utility=jnp.mean,
- cpfs_without_grad: Set=set()) -> None:
+ utility: Union[Callable[[jnp.ndarray], float], str]='mean',
+ utility_kwargs: Optional[Kwargs]=None,
+ cpfs_without_grad: Optional[Set[str]]=None,
+ compile_non_fluent_exact: bool=True,
+ logger: Optional[Logger]=None) -> None:
  '''Creates a new gradient-based algorithm for optimizing action sequences
  (plan) in the given RDDL. Some operations will be converted to their
  differentiable counterparts; the specific operations can be customized
@@ -946,9 +1152,16 @@ class JaxBackpropPlanner:
  :param use_symlog_reward: whether to use the symlog transform on the
  reward as a form of normalization
  :param utility: how to aggregate return observations to compute utility
- of a policy or plan
+ of a policy or plan; must be either a function mapping jax array to a
+ scalar, or a a string identifying the utility function by name
+ ("mean", "mean_var", "entropic", or "cvar" are currently supported)
+ :param utility_kwargs: additional keyword arguments to pass hyper-
+ parameters to the utility function call
  :param cpfs_without_grad: which CPFs do not have gradients (use straight
  through gradient trick)
+ :param compile_non_fluent_exact: whether non-fluent expressions
+ are always compiled using exact JAX expressions
+ :param logger: to log information about compilation to file
  '''
  self.rddl = rddl
  self.plan = plan
@@ -959,22 +1172,25 @@ class JaxBackpropPlanner:
  if rollout_horizon is None:
  rollout_horizon = rddl.horizon
  self.horizon = rollout_horizon
+ if action_bounds is None:
+ action_bounds = {}
  self._action_bounds = action_bounds
  self.use64bit = use64bit
  self._optimizer_name = optimizer
+ if optimizer_kwargs is None:
+ optimizer_kwargs = {'learning_rate': 0.1}
  self._optimizer_kwargs = optimizer_kwargs
  self.clip_grad = clip_grad
 
  # set optimizer
  try:
  optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs)
- except:
+ except Exception as _:
  raise_warning(
  'Failed to inject hyperparameters into optax optimizer, '
  'rolling back to safer method: please note that modification of '
  'optimizer hyperparameters will not work, and it is '
- 'recommended to update your packages and Python distribution.',
- 'red')
+ 'recommended to update optax and related packages.', 'red')
  optimizer = optimizer(**optimizer_kwargs)
  if clip_grad is None:
  self.optimizer = optimizer
@@ -983,22 +1199,68 @@ class JaxBackpropPlanner:
  optax.clip(clip_grad),
  optimizer
  )
-
+
+ # set utility
+ if isinstance(utility, str):
+ utility = utility.lower()
+ if utility == 'mean':
+ utility_fn = jnp.mean
+ elif utility == 'mean_var':
+ utility_fn = mean_variance_utility
+ elif utility == 'entropic':
+ utility_fn = entropic_utility
+ elif utility == 'cvar':
+ utility_fn = cvar_utility
+ else:
+ raise RDDLNotImplementedError(
+ f'Utility function <{utility}> is not supported: '
+ 'must be one of ["mean", "mean_var", "entropic", "cvar"].')
+ else:
+ utility_fn = utility
+ self.utility = utility_fn
+
+ if utility_kwargs is None:
+ utility_kwargs = {}
+ self.utility_kwargs = utility_kwargs
+
  self.logic = logic
+ self.logic.set_use64bit(self.use64bit)
  self.use_symlog_reward = use_symlog_reward
- self.utility = utility
+ if cpfs_without_grad is None:
+ cpfs_without_grad = set()
  self.cpfs_without_grad = cpfs_without_grad
+ self.compile_non_fluent_exact = compile_non_fluent_exact
+ self.logger = logger
 
  self._jax_compile_rddl()
  self._jax_compile_optimizer()
-
- def summarize_hyperparameters(self):
- print(f'objective and relaxations:\n'
- f' objective_fn ={self.utility.__name__}\n'
+
+ def _summarize_system(self) -> None:
+ try:
+ jaxlib_version = jax._src.lib.version_str
+ except Exception as _:
+ jaxlib_version = 'N/A'
+ try:
+ devices_short = ', '.join(
+ map(str, jax._src.xla_bridge.devices())).replace('\n', '')
+ except Exception as _:
+ devices_short = 'N/A'
+ print('\n'
+ f'JAX Planner version {__version__}\n'
+ f'Python {sys.version}\n'
+ f'jax {jax.version.__version__}, jaxlib {jaxlib_version}, '
+ f'numpy {np.__version__}\n'
+ f'devices: {devices_short}\n')
+
+ def summarize_hyperparameters(self) -> None:
+ print(f'objective hyper-parameters:\n'
+ f' utility_fn ={self.utility.__name__}\n'
+ f' utility args ={self.utility_kwargs}\n'
  f' use_symlog ={self.use_symlog_reward}\n'
  f' lookahead ={self.horizon}\n'
- f' model relaxation={type(self.logic).__name__}\n'
  f' action_bounds ={self._action_bounds}\n'
+ f' fuzzy logic type={type(self.logic).__name__}\n'
+ f' nonfluents exact={self.compile_non_fluent_exact}\n'
  f' cpfs_no_gradient={self.cpfs_without_grad}\n'
  f'optimizer hyper-parameters:\n'
  f' use_64_bit ={self.use64bit}\n'
@@ -1010,6 +1272,10 @@ class JaxBackpropPlanner:
  self.plan.summarize_hyperparameters()
  self.logic.summarize_hyperparameters()
 
+ # ===========================================================================
+ # COMPILATION SUBROUTINES
+ # ===========================================================================
+
  def _jax_compile_rddl(self):
  rddl = self.rddl
 
@@ -1017,13 +1283,18 @@ class JaxBackpropPlanner:
  self.compiled = JaxRDDLCompilerWithGrad(
  rddl=rddl,
  logic=self.logic,
+ logger=self.logger,
  use64bit=self.use64bit,
- cpfs_without_grad=self.cpfs_without_grad)
- self.compiled.compile()
+ cpfs_without_grad=self.cpfs_without_grad,
+ compile_non_fluent_exact=self.compile_non_fluent_exact)
+ self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')
 
  # Jax compilation of the exact RDDL for testing
- self.test_compiled = JaxRDDLCompiler(rddl=rddl, use64bit=self.use64bit)
- self.test_compiled.compile()
+ self.test_compiled = JaxRDDLCompiler(
+ rddl=rddl,
+ logger=self.logger,
+ use64bit=self.use64bit)
+ self.test_compiled.compile(log_jax_expr=True, heading='EXACT MODEL')
 
  def _jax_compile_optimizer(self):
 
@@ -1051,11 +1322,10 @@
 
  # losses
  train_loss = self._jax_loss(train_rollouts, use_symlog=self.use_symlog_reward)
- self.train_loss = jax.jit(train_loss)
  self.test_loss = jax.jit(self._jax_loss(test_rollouts, use_symlog=False))
 
  # optimization
- self.update = jax.jit(self._jax_update(train_loss))
+ self.update = self._jax_update(train_loss)
 
  def _jax_return(self, use_symlog):
  gamma = self.rddl.discount
@@ -1068,13 +1338,14 @@
  rewards = rewards * discount[jnp.newaxis, ...]
  returns = jnp.sum(rewards, axis=1)
  if use_symlog:
- returns = jnp.sign(returns) * jnp.log1p(jnp.abs(returns))
+ returns = jnp.sign(returns) * jnp.log(1.0 + jnp.abs(returns))
  return returns
 
  return _jax_wrapped_returns
 
  def _jax_loss(self, rollouts, use_symlog=False):
- utility_fn = self.utility
+ utility_fn = self.utility
+ utility_kwargs = self.utility_kwargs
  _jax_wrapped_returns = self._jax_return(use_symlog)
 
  # the loss is the average cumulative reward across all roll-outs
@@ -1083,7 +1354,7 @@
  log = rollouts(key, policy_params, hyperparams, subs, model_params)
  rewards = log['reward']
  returns = _jax_wrapped_returns(rewards)
- utility = utility_fn(returns)
+ utility = utility_fn(returns, **utility_kwargs)
  loss = -utility
  return loss, log
 
@@ -1096,7 +1367,7 @@
  def _jax_wrapped_init_policy(key, hyperparams, subs):
  policy_params = init(key, hyperparams, subs)
  opt_state = optimizer.init(policy_params)
- return policy_params, opt_state
+ return policy_params, opt_state, None
 
  return _jax_wrapped_init_policy
 
@@ -1107,17 +1378,18 @@
  # calculate the plan gradient w.r.t. return loss and update optimizer
  # also perform a projection step to satisfy constraints on actions
  def _jax_wrapped_plan_update(key, policy_params, hyperparams,
- subs, model_params, opt_state):
- grad_fn = jax.grad(loss, argnums=1, has_aux=True)
- grad, log = grad_fn(key, policy_params, hyperparams, subs, model_params)
+ subs, model_params, opt_state, opt_aux):
+ grad_fn = jax.value_and_grad(loss, argnums=1, has_aux=True)
+ (loss_val, log), grad = grad_fn(
+ key, policy_params, hyperparams, subs, model_params)
  updates, opt_state = optimizer.update(grad, opt_state)
  policy_params = optax.apply_updates(policy_params, updates)
  policy_params, converged = projection(policy_params, hyperparams)
  log['grad'] = grad
  log['updates'] = updates
- return policy_params, converged, opt_state, log
+ return policy_params, converged, opt_state, None, loss_val, log
 
- return _jax_wrapped_plan_update
+ return jax.jit(_jax_wrapped_plan_update)
 
  def _batched_init_subs(self, subs):
  rddl = self.rddl
@@ -1145,13 +1417,15 @@
 
  return init_train, init_test
 
- def optimize(self, *args, return_callback: bool=False, **kwargs) -> object:
- ''' Compute an optimal straight-line plan. Returns the parameters
- for the optimized policy.
+ # ===========================================================================
+ # OPTIMIZE API
+ # ===========================================================================
+
+ def optimize(self, *args, **kwargs) -> Dict[str, Any]:
+ ''' Compute an optimal policy or plan. Return the callback from training.
 
- :param key: JAX PRNG key
+ :param key: JAX PRNG key (derived from clock if not provided)
  :param epochs: the maximum number of steps of gradient descent
- :param the maximum number of steps of gradient descent
  :param train_seconds: total time allocated for gradient descent
  :param plot_step: frequency to plot the plan and save result to disk
  :param model_params: optional model-parameters to override default
@@ -1162,33 +1436,44 @@
  :param guess: initial policy parameters: if None will use the initializer
  specified in this instance
  :param verbose: not print (0), print summary (1), print progress (2)
- :param return_callback: whether to return the callback from training
- instead of the parameters
+ :param test_rolling_window: the test return is averaged on a rolling
+ window of the past test_rolling_window returns when updating the best
+ parameters found so far
+ :param tqdm_position: position of tqdm progress bar (for multiprocessing)
  '''
  it = self.optimize_generator(*args, **kwargs)
- callback = deque(it, maxlen=1).pop()
- if return_callback:
- return callback
+
+ # if the python is C-compiled then the deque is native C and much faster
+ # than naively exhausting iterator, but not if the python is some other
+ # version (e.g. PyPi); for details, see
+ # https://stackoverflow.com/questions/50937966/fastest-most-pythonic-way-to-consume-an-iterator
+ callback = None
+ if sys.implementation.name == 'cpython':
+ last_callback = deque(it, maxlen=1)
+ if last_callback:
+ callback = last_callback.pop()
  else:
- return callback['best_params']
+ for callback in it:
+ pass
+ return callback
 
- def optimize_generator(self, key: random.PRNGKey,
+ def optimize_generator(self, key: Optional[random.PRNGKey]=None,
  epochs: int=999999,
  train_seconds: float=120.,
- plot_step: int=None,
- model_params: Dict[str, object]=None,
- policy_hyperparams: Dict[str, object]=None,
- subs: Dict[str, object]=None,
- guess: Dict[str, object]=None,
+ plot_step: Optional[int]=None,
+ model_params: Optional[Dict[str, Any]]=None,
+ policy_hyperparams: Optional[Dict[str, Any]]=None,
+ subs: Optional[Dict[str, Any]]=None,
+ guess: Optional[Pytree]=None,
  verbose: int=2,
- tqdm_position: int=None) -> Generator[Dict[str, object], None, None]:
- '''Returns a generator for computing an optimal straight-line plan.
+ test_rolling_window: int=10,
+ tqdm_position: Optional[int]=None) -> Generator[Dict[str, Any], None, None]:
+ '''Returns a generator for computing an optimal policy or plan.
  Generator can be iterated over to lazily optimize the plan, yielding
  a dictionary of intermediate computations.
 
- :param key: JAX PRNG key
+ :param key: JAX PRNG key (derived from clock if not provided)
  :param epochs: the maximum number of steps of gradient descent
- :param the maximum number of steps of gradient descent
  :param train_seconds: total time allocated for gradient descent
  :param plot_step: frequency to plot the plan and save result to disk
  :param model_params: optional model-parameters to override default
@@ -1199,26 +1484,53 @@
  :param guess: initial policy parameters: if None will use the initializer
  specified in this instance
  :param verbose: not print (0), print summary (1), print progress (2)
+ :param test_rolling_window: the test return is averaged on a rolling
+ window of the past test_rolling_window returns when updating the best
+ parameters found so far
  :param tqdm_position: position of tqdm progress bar (for multiprocessing)
  '''
  verbose = int(verbose)
  start_time = time.time()
  elapsed_outside_loop = 0
 
+ # if PRNG key is not provided
+ if key is None:
+ key = random.PRNGKey(round(time.time() * 1000))
+
+ # if policy_hyperparams is not provided
+ if policy_hyperparams is None:
+ raise_warning('policy_hyperparams is not set, setting 1.0 for '
+ 'all action-fluents which could be suboptimal.')
+ policy_hyperparams = {action: 1.0
+ for action in self.rddl.action_fluents}
+
+ # if policy_hyperparams is a scalar
+ elif isinstance(policy_hyperparams, (int, float, np.number)):
+ raise_warning(f'policy_hyperparams is {policy_hyperparams}, '
+ 'setting this value for all action-fluents.')
+ hyperparam_value = float(policy_hyperparams)
+ policy_hyperparams = {action: hyperparam_value
+ for action in self.rddl.action_fluents}
+
  # print summary of parameters:
  if verbose >= 1:
- print('==============================================\n'
- 'JAX PLANNER PARAMETER SUMMARY\n'
- '==============================================')
+ self._summarize_system()
  self.summarize_hyperparameters()
  print(f'optimize() call hyper-parameters:\n'
+ f' PRNG key ={key}\n'
  f' max_iterations ={epochs}\n'
  f' max_seconds ={train_seconds}\n'
  f' model_params ={model_params}\n'
  f' policy_hyper_params={policy_hyperparams}\n'
  f' override_subs_dict ={subs is not None}\n'
- f' provide_param_guess={guess is not None}\n'
- f' plot_frequency ={plot_step}\n')
+ f' provide_param_guess={guess is not None}\n'
+ f' test_rolling_window={test_rolling_window}\n'
+ f' plot_frequency ={plot_step}\n'
+ f' verbose ={verbose}\n')
+ if verbose >= 2 and self.compiled.relaxations:
+ print('Some RDDL operations are non-differentiable, '
+ 'replacing them with differentiable relaxations:')
+ print(self.compiled.summarize_model_relaxations())
 
  # compute a batched version of the initial values
  if subs is None:
@@ -1245,14 +1557,26 @@
  # initialize policy parameters
  if guess is None:
  key, subkey = random.split(key)
- policy_params, opt_state = self.initialize(
+ policy_params, opt_state, opt_aux = self.initialize(
  subkey, policy_hyperparams, train_subs)
  else:
  policy_params = guess
  opt_state = self.optimizer.init(policy_params)
+ opt_aux = None
+
+ # initialize running statistics
  best_params, best_loss, best_grad = policy_params, jnp.inf, jnp.inf
  last_iter_improve = 0
+ rolling_test_loss = RollingMean(test_rolling_window)
  log = {}
+ status = JaxPlannerStatus.NORMAL
+
+ # initialize plot area
+ if plot_step is None or plot_step <= 0 or plt is None:
+ plot = None
+ else:
+ plot = JaxPlannerPlot(self.rddl, self.horizon)
+ xticks, loss_values = [], []
 
  # training loop
  iters = range(epochs)
@@ -1260,25 +1584,25 @@
  iters = tqdm(iters, total=100, position=tqdm_position)
 
  for it in iters:
+ status = JaxPlannerStatus.NORMAL
 
  # update the parameters of the plan
- key, subkey1, subkey2, subkey3 = random.split(key, num=4)
- policy_params, converged, opt_state, train_log = self.update(
- subkey1, policy_params, policy_hyperparams,
- train_subs, model_params, opt_state)
+ key, subkey = random.split(key)
+ policy_params, converged, opt_state, opt_aux, train_loss, train_log = \
+ self.update(subkey, policy_params, policy_hyperparams,
+ train_subs, model_params, opt_state, opt_aux)
  if not np.all(converged):
  raise_warning(
  'Projected gradient method for satisfying action concurrency '
  'constraints reached the iteration limit: plan is possibly '
  'invalid for the current instance.', 'red')
+ status = JaxPlannerStatus.PRECONDITION_POSSIBLY_UNSATISFIED
 
  # evaluate losses
- train_loss, _ = self.train_loss(
- subkey2, policy_params, policy_hyperparams,
- train_subs, model_params)
  test_loss, log = self.test_loss(
- subkey3, policy_params, policy_hyperparams,
+ subkey, policy_params, policy_hyperparams,
  test_subs, model_params_test)
+ test_loss = rolling_test_loss.update(test_loss)
 
  # record the best plan so far
  if test_loss < best_loss:
@@ -1287,21 +1611,45 @@
  last_iter_improve = it
 
  # save the plan figure
- if plot_step is not None and it % plot_step == 0:
- self._plot_actions(
- key, policy_params, policy_hyperparams, test_subs, it)
+ if plot is not None and it % plot_step == 0:
+ xticks.append(it // plot_step)
+ loss_values.append(test_loss.item())
+ action_values = {name: values
+ for (name, values) in log['fluents'].items()
+ if name in self.rddl.action_fluents}
+ plot.redraw(xticks, loss_values, action_values)
 
  # if the progress bar is used
  elapsed = time.time() - start_time - elapsed_outside_loop
  if verbose >= 2:
  iters.n = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
  iters.set_description(
- f'[{tqdm_position}] {it:6} it / {-train_loss:14.4f} train / '
- f'{-test_loss:14.4f} test / {-best_loss:14.4f} best')
+ f'[{tqdm_position}] {it:6} it / {-train_loss:14.6f} train / '
+ f'{-test_loss:14.6f} test / {-best_loss:14.6f} best')
+
+ # reached computation budget
+ if elapsed >= train_seconds:
+ status = JaxPlannerStatus.TIME_BUDGET_REACHED
+ if it >= epochs - 1:
+ status = JaxPlannerStatus.ITER_BUDGET_REACHED
+
+ # numerical error
+ if not np.isfinite(train_loss):
+ raise_warning(
+ f'Aborting JAX planner due to invalid train loss {train_loss}.',
+ 'red')
+ status = JaxPlannerStatus.INVALID_GRADIENT
+
+ # no progress
+ grad_norm_zero, _ = jax.tree_util.tree_flatten(
+ jax.tree_map(lambda x: np.allclose(x, 0), train_log['grad']))
+ if np.all(grad_norm_zero):
+ status = JaxPlannerStatus.NO_PROGRESS
 
  # return a callback
  start_time_outside = time.time()
  yield {
+ 'status': status,
  'iteration': it,
  'train_return':-train_loss,
  'test_return':-test_loss,
@@ -1318,16 +1666,15 @@
  }
  elapsed_outside_loop += (time.time() - start_time_outside)
 
- # reached time budget
- if elapsed >= train_seconds:
- break
-
- # numerical error
- if not np.isfinite(train_loss):
+ # abortion check
+ if status.is_failure():
  break
-
+
+ # release resources
  if verbose >= 2:
  iters.close()
+ if plot is not None:
+ plot.close()
 
  # validate the test return
  if log:
@@ -1337,24 +1684,23 @@
  if messages:
  messages = '\n'.join(messages)
  raise_warning('The JAX compiler encountered the following '
- 'problems in the original RDDL '
+ 'error(s) in the original RDDL formulation '
  f'during test evaluation:\n{messages}', 'red')
 
  # summarize and test for convergence
  if verbose >= 1:
- grad_norm = jax.tree_map(
- lambda x: np.array(jnp.linalg.norm(x)).item(), best_grad)
+ grad_norm = jax.tree_map(lambda x: np.linalg.norm(x).item(), best_grad)
  diagnosis = self._perform_diagnosis(
- last_iter_improve, it,
- -train_loss, -test_loss, -best_loss, grad_norm)
+ last_iter_improve, -train_loss, -test_loss, -best_loss, grad_norm)
  print(f'summary of optimization:\n'
+ f' status_code ={status}\n'
  f' time_elapsed ={elapsed}\n'
  f' iterations ={it}\n'
  f' best_objective={-best_loss}\n'
- f' grad_norm ={grad_norm}\n'
+ f' best_grad_norm={grad_norm}\n'
  f'diagnosis: {diagnosis}\n')
 
- def _perform_diagnosis(self, last_iter_improve, total_it,
+ def _perform_diagnosis(self, last_iter_improve,
  train_return, test_return, best_return, grad_norm):
  max_grad_norm = max(jax.tree_util.tree_leaves(grad_norm))
  grad_is_zero = np.allclose(max_grad_norm, 0)
@@ -1373,20 +1719,20 @@
  if grad_is_zero:
  return termcolor.colored(
  '[FAILURE] no progress was made, '
- f'and max grad norm = {max_grad_norm}, '
- 'likely stuck in a plateau.', 'red')
+ f'and max grad norm {max_grad_norm:.6f} is zero: '
+ 'solver likely stuck in a plateau.', 'red')
  else:
  return termcolor.colored(
  '[FAILURE] no progress was made, '
- f'but max grad norm = {max_grad_norm} > 0, '
- 'likely due to bad l.r. or other hyper-parameter.', 'red')
+ f'but max grad norm {max_grad_norm:.6f} is non-zero: '
+ 'likely poor learning rate or other hyper-parameter.', 'red')
 
  # model is likely poor IF:
  # 1. the train and test return disagree
  if not (validation_error < 20):
  return termcolor.colored(
  '[WARNING] progress was made, '
- f'but relative train test error = {validation_error} is high, '
+ f'but relative train-test error {validation_error:.6f} is high: '
  'likely poor model relaxation around the solution, '
  'or the batch size is too small.', 'yellow')
 
@@ -1397,208 +1743,216 @@ class JaxBackpropPlanner:
1397
1743
  if not (return_to_grad_norm > 1):
1398
1744
  return termcolor.colored(
1399
1745
  '[WARNING] progress was made, '
1400
- f'but max grad norm = {max_grad_norm} is high, '
1401
- 'likely indicates the solution is not locally optimal, '
1402
- 'or the model is not smooth around the solution, '
1746
+ f'but max grad norm {max_grad_norm:.6f} is high: '
1747
+ 'likely the solution is not locally optimal, '
1748
+ 'or the relaxed model is not smooth around the solution, '
1403
1749
  'or the batch size is too small.', 'yellow')
1404
1750
 
1405
1751
  # likely successful
1406
1752
  return termcolor.colored(
1407
- '[SUCCESS] planner appears to have converged successfully '
1753
+ '[SUCCESS] planner has converged successfully '
1408
1754
  '(note: not all potential problems can be ruled out).', 'green')
1409
1755
 
1410
1756
  def get_action(self, key: random.PRNGKey,
1411
- params: Dict,
1757
+ params: Pytree,
1412
1758
  step: int,
1413
- subs: Dict,
1414
- policy_hyperparams: Dict[str, object]=None) -> Dict[str, object]:
1759
+ subs: Dict[str, Any],
1760
+ policy_hyperparams: Optional[Dict[str, Any]]=None) -> Dict[str, Any]:
1415
1761
  '''Returns an action dictionary from the policy or plan with the given
1416
1762
  parameters.
1417
1763
 
1418
1764
  :param key: the JAX PRNG key
1419
1765
  :param params: the trainable parameter PyTree of the policy
1420
1766
  :param step: the time step at which decision is made
1421
- :param policy_hyperparams: hyper-parameters for the policy/plan, such as
1422
- weights for sigmoid wrapping boolean actions
1423
1767
  :param subs: the dict of pvariables
1768
+ :param policy_hyperparams: hyper-parameters for the policy/plan, such as
1769
+ weights for sigmoid wrapping boolean actions (optional)
1424
1770
  '''
1425
1771
 
1426
1772
  # check compatibility of the subs dictionary
1427
- for var in subs.keys():
1773
+ for (var, values) in subs.items():
1774
+
1775
+ # must not be grounded
1428
1776
  if RDDLPlanningModel.FLUENT_SEP in var \
1429
1777
  or RDDLPlanningModel.OBJECT_SEP in var:
1430
- raise Exception(f'State dictionary passed to the JAX policy is '
1431
- f'grounded, since it contains the key <{var}>, '
1432
- f'but a vectorized environment is required: '
1433
- f'please make sure vectorized=True in the RDDLEnv.')
1434
-
1778
+ raise ValueError(f'State dictionary passed to the JAX policy is '
1779
+ f'grounded, since it contains the key <{var}>, '
1780
+ f'but a vectorized environment is required: '
1781
+ f'please make sure vectorized=True in the RDDLEnv.')
1782
+
1783
+ # must be numeric array
1784
+ # exception is for POMDPs at 1st epoch when observ-fluents are None
1785
+ if not jnp.issubdtype(values.dtype, jnp.number) \
1786
+ and not jnp.issubdtype(values.dtype, jnp.bool_):
1787
+ if step == 0 and var in self.rddl.observ_fluents:
1788
+ subs[var] = self.test_compiled.init_values[var]
1789
+ else:
1790
+ raise ValueError(f'Values assigned to pvariable {var} are '
1791
+ f'non-numeric of type {values.dtype}: {values}.')
1792
+
1435
1793
  # cast device arrays to numpy
1436
1794
  actions = self.test_policy(key, params, policy_hyperparams, step, subs)
1437
1795
  actions = jax.tree_map(np.asarray, actions)
1438
1796
  return actions
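To illustrate the shape of subs that the check above accepts: a vectorized RDDLEnv returns one array per lifted fluent, while a grounded environment returns per-object scalars whose keys embed the fluent and object separators, and those keys are rejected. The fluent name and the '___' separator below are illustrative assumptions only.

import numpy as np

# accepted: lifted fluent name mapped to an array over objects (vectorized=True)
subs_vectorized = {'rlevel': np.array([55.0, 45.0, 60.0])}

# rejected: grounded keys embedding the fluent/object separator
subs_grounded = {'rlevel___t1': 55.0, 'rlevel___t2': 45.0}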
1439
-
1440
- def _plot_actions(self, key, params, hyperparams, subs, it):
1441
- rddl = self.rddl
1442
- try:
1443
- import matplotlib.pyplot as plt
1444
- except Exception:
1445
- print('matplotlib is not installed, aborting plot...')
1446
- return
1447
-
1448
- # predict actions from the trained policy or plan
1449
- actions = self.test_rollouts(key, params, hyperparams, subs, {})['action']
1450
-
1451
- # plot the action sequences as color maps
1452
- fig, axs = plt.subplots(nrows=len(actions), constrained_layout=True)
1453
- for (ax, name) in zip(axs, actions):
1454
- action = np.mean(actions[name], axis=0, dtype=float)
1455
- action = np.reshape(action, newshape=(action.shape[0], -1)).T
1456
- if rddl.variable_ranges[name] == 'bool':
1457
- vmin, vmax = 0.0, 1.0
1458
- else:
1459
- vmin, vmax = None, None
1460
- img = ax.imshow(
1461
- action, vmin=vmin, vmax=vmax, cmap='seismic', aspect='auto')
1462
- ax.set_xlabel('time')
1463
- ax.set_ylabel(name)
1464
- plt.colorbar(img, ax=ax)
1465
-
1466
- # write plot to disk
1467
- plt.savefig(f'plan_{rddl.domain_name}_{rddl.instance_name}_{it}.pdf',
1468
- bbox_inches='tight')
1469
- plt.clf()
1470
- plt.close(fig)
1471
1797
 
1472
1798
 
1473
- class JaxArmijoLineSearchPlanner(JaxBackpropPlanner):
1799
+ class JaxLineSearchPlanner(JaxBackpropPlanner):
1474
1800
  '''A class for optimizing an action sequence in the given RDDL MDP using
1475
- Armijo linear search gradient descent.'''
1801
+ linear search gradient descent, with the Armijo condition.'''
1476
1802
 
1477
1803
  def __init__(self, *args,
1478
1804
  optimizer: Callable[..., optax.GradientTransformation]=optax.sgd,
1479
- optimizer_kwargs: Dict[str, object]={'learning_rate': 1.0},
1480
- beta: float=0.8,
1805
+ optimizer_kwargs: Kwargs={'learning_rate': 1.0},
1806
+ decay: float=0.8,
1481
1807
  c: float=0.1,
1482
- lrmax: float=1.0,
1483
- lrmin: float=1e-5,
1808
+ step_max: float=1.0,
1809
+ step_min: float=1e-6,
1484
1810
  **kwargs) -> None:
1485
1811
  '''Creates a new gradient-based algorithm for optimizing action sequences
1486
- (plan) in the given RDDL using Armijo line search. All arguments are the
1812
+ (plan) in the given RDDL using line search. All arguments are the
1487
1813
  same as in the parent class, except:
1488
1814
 
1489
- :param beta: reduction factor of learning rate per line search iteration
1490
- :param c: coefficient in Armijo condition
1491
- :param lrmax: initial learning rate for line search
1492
- :param lrmin: minimum possible learning rate (line search halts)
1815
+ :param decay: reduction factor of learning rate per line search iteration
1816
+ :param c: positive coefficient in Armijo condition, should be in (0, 1)
1817
+ :param step_max: initial learning rate for line search
1818
+ :param step_min: minimum possible learning rate (line search halts)
1493
1819
  '''
1494
- self.beta = beta
1820
+ self.decay = decay
1495
1821
  self.c = c
1496
- self.lrmax = lrmax
1497
- self.lrmin = lrmin
1498
- super(JaxArmijoLineSearchPlanner, self).__init__(
1822
+ self.step_max = step_max
1823
+ self.step_min = step_min
1824
+ if 'clip_grad' in kwargs:
1825
+ raise_warning('clip_grad parameter conflicts with '
1826
+ 'line search planner and will be ignored.', 'red')
1827
+ del kwargs['clip_grad']
1828
+ super(JaxLineSearchPlanner, self).__init__(
1499
1829
  *args,
1500
1830
  optimizer=optimizer,
1501
1831
  optimizer_kwargs=optimizer_kwargs,
1502
1832
  **kwargs)
1503
1833
 
1504
- def summarize_hyperparameters(self):
1505
- super(JaxArmijoLineSearchPlanner, self).summarize_hyperparameters()
1834
+ def summarize_hyperparameters(self) -> None:
1835
+ super(JaxLineSearchPlanner, self).summarize_hyperparameters()
1506
1836
  print(f'linesearch hyper-parameters:\n'
1507
- f' beta ={self.beta}\n'
1837
+ f' decay ={self.decay}\n'
1508
1838
  f' c ={self.c}\n'
1509
- f' lr_range=({self.lrmin}, {self.lrmax})\n')
1839
+ f' lr_range=({self.step_min}, {self.step_max})')
1510
1840
 
1511
1841
  def _jax_update(self, loss):
1512
1842
  optimizer = self.optimizer
1513
1843
  projection = self.plan.projection
1514
- beta, c, lrmax, lrmin = self.beta, self.c, self.lrmax, self.lrmin
1515
-
1516
- # continue line search if Armijo condition not satisfied and learning
1517
- # rate can be further reduced
1518
- def _jax_wrapped_line_search_armijo_check(val):
1519
- (_, old_f, _, old_norm_g2, _), (_, new_f, lr, _), _, _ = val
1520
- return jnp.logical_and(
1521
- new_f >= old_f - c * lr * old_norm_g2,
1522
- lr >= lrmin / beta)
1523
-
1524
- def _jax_wrapped_line_search_iteration(val):
1525
- old, new, best, aux = val
1526
- old_x, _, old_g, _, old_state = old
1527
- _, _, lr, iters = new
1528
- _, best_f, _, _ = best
1529
- key, hyperparams, *other = aux
1530
-
1531
- # anneal learning rate and apply a gradient step
1532
- new_lr = beta * lr
1533
- old_state.hyperparams['learning_rate'] = new_lr
1534
- updates, new_state = optimizer.update(old_g, old_state)
1535
- new_x = optax.apply_updates(old_x, updates)
1536
- new_x, _ = projection(new_x, hyperparams)
1537
-
1538
- # evaluate new loss and record best so far
1539
- new_f, _ = loss(key, new_x, hyperparams, *other)
1540
- new = (new_x, new_f, new_lr, iters + 1)
1541
- best = jax.lax.cond(
1542
- new_f < best_f,
1543
- lambda: (new_x, new_f, new_lr, new_state),
1544
- lambda: best
1545
- )
1546
- return old, new, best, aux
1844
+ decay, c, lrmax, lrmin = self.decay, self.c, self.step_max, self.step_min
1845
+
1846
+ # initialize the line search routine
1847
+ @jax.jit
1848
+ def _jax_wrapped_line_search_init(key, policy_params, hyperparams,
1849
+ subs, model_params):
1850
+ (f, log), grad = jax.value_and_grad(loss, argnums=1, has_aux=True)(
1851
+ key, policy_params, hyperparams, subs, model_params)
1852
+ gnorm2 = jax.tree_map(lambda x: jnp.sum(jnp.square(x)), grad)
1853
+ gnorm2 = jax.tree_util.tree_reduce(jnp.add, gnorm2)
1854
+ log['grad'] = grad
1855
+ return f, grad, gnorm2, log
1547
1856
 
1857
+ # compute the next trial solution
1858
+ @jax.jit
1859
+ def _jax_wrapped_line_search_trial(
1860
+ step, grad, key, params, hparams, subs, mparams, state):
1861
+ state.hyperparams['learning_rate'] = step
1862
+ updates, new_state = optimizer.update(grad, state)
1863
+ new_params = optax.apply_updates(params, updates)
1864
+ new_params, _ = projection(new_params, hparams)
1865
+ f_step, _ = loss(key, new_params, hparams, subs, mparams)
1866
+ return f_step, new_params, new_state
1867
+
1868
+ # main iteration of line search
1548
1869
  def _jax_wrapped_plan_update(key, policy_params, hyperparams,
1549
- subs, model_params, opt_state):
1550
-
1551
- # calculate initial loss value, gradient and squared norm
1552
- old_x = policy_params
1553
- loss_and_grad_fn = jax.value_and_grad(loss, argnums=1, has_aux=True)
1554
- (old_f, log), old_g = loss_and_grad_fn(
1555
- key, old_x, hyperparams, subs, model_params)
1556
- old_norm_g2 = jax.tree_map(lambda x: jnp.sum(jnp.square(x)), old_g)
1557
- old_norm_g2 = jax.tree_util.tree_reduce(jnp.add, old_norm_g2)
1558
- log['grad'] = old_g
1870
+ subs, model_params, opt_state, opt_aux):
1559
1871
 
1560
- # initialize learning rate to maximum
1561
- new_lr = lrmax / beta
1562
- old = (old_x, old_f, old_g, old_norm_g2, opt_state)
1563
- new = (old_x, old_f, new_lr, 0)
1564
- best = (old_x, jnp.inf, jnp.nan, opt_state)
1565
- aux = (key, hyperparams, subs, model_params)
1872
+ # initialize the line search
1873
+ f, grad, gnorm2, log = _jax_wrapped_line_search_init(
1874
+ key, policy_params, hyperparams, subs, model_params)
1566
1875
 
1567
- # do a single line search step with the initial learning rate
1568
- init_val = (old, new, best, aux)
1569
- init_val = _jax_wrapped_line_search_iteration(init_val)
1876
+ # continue to reduce the learning rate until the Armijo condition holds
1877
+ trials = 0
1878
+ step = lrmax / decay
1879
+ f_step = np.inf
1880
+ best_f, best_step, best_params, best_state = np.inf, None, None, None
1881
+ while f_step > f - c * step * gnorm2 and step * decay >= lrmin:
1882
+ trials += 1
1883
+ step *= decay
1884
+ f_step, new_params, new_state = _jax_wrapped_line_search_trial(
1885
+ step, grad, key, policy_params, hyperparams, subs,
1886
+ model_params, opt_state)
1887
+ if f_step < best_f:
1888
+ best_f, best_step, best_params, best_state = \
1889
+ f_step, step, new_params, new_state
1570
1890
 
1571
- # continue to anneal the learning rate until Armijo condition holds
1572
- # or the learning rate becomes too small, then use the best parameter
1573
- _, (*_, iters), (best_params, _, best_lr, best_state), _ = \
1574
- jax.lax.while_loop(
1575
- cond_fun=_jax_wrapped_line_search_armijo_check,
1576
- body_fun=_jax_wrapped_line_search_iteration,
1577
- init_val=init_val
1578
- )
1579
- best_state.hyperparams['learning_rate'] = best_lr
1580
1891
  log['updates'] = None
1581
- log['line_search_iters'] = iters
1582
- log['learning_rate'] = best_lr
1583
- return best_params, True, best_state, log
1892
+ log['line_search_iters'] = trials
1893
+ log['learning_rate'] = best_step
1894
+ return best_params, True, best_state, best_step, best_f, log
1584
1895
 
1585
1896
  return _jax_wrapped_plan_update
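The update above shrinks the step size until the Armijo condition f(x - step * g) <= f(x) - c * step * ||g||^2 holds or the step falls below step_min. A self-contained sketch of the same backtracking rule on a toy quadratic (the toy objective and the function name are illustrative, not from the package; the planner additionally keeps the best trial seen):

import jax
import jax.numpy as jnp

def armijo_backtrack(f, x, step_max=1.0, step_min=1e-6, decay=0.8, c=0.1):
    fx, g = jax.value_and_grad(f)(x)
    gnorm2 = jnp.sum(jnp.square(g))
    step = step_max / decay
    f_step = jnp.inf
    # shrink the step until sufficient decrease is achieved or the step is too small
    while f_step > fx - c * step * gnorm2 and step * decay >= step_min:
        step *= decay
        f_step = f(x - step * g)
    return x - step * g, step

f = lambda x: jnp.sum((x - 2.0) ** 2)           # toy objective with minimum at x = 2
x_new, step = armijo_backtrack(f, jnp.zeros(3))
print(step, x_new)                              # accepts step = 0.8 on this problem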
1586
1897
 
1587
-
1898
+
1899
+ # ***********************************************************************
1900
+ # ALL VERSIONS OF RISK FUNCTIONS
1901
+ #
1902
+ # Based on the original paper "A Distributional Framework for Risk-Sensitive
1903
+ # End-to-End Planning in Continuous MDPs" by Patton et al., AAAI 2022.
1904
+ #
1905
+ # Original risk functions:
1906
+ # - entropic utility
1907
+ # - mean-variance approximation
1908
+ # - conditional value at risk with straight-through gradient trick
1909
+ #
1910
+ # ***********************************************************************
1911
+
1912
+
1913
+ @jax.jit
1914
+ def entropic_utility(returns: jnp.ndarray, beta: float) -> float:
1915
+ return (-1.0 / beta) * jax.scipy.special.logsumexp(
1916
+ -beta * returns, b=1.0 / returns.size)
1917
+
1918
+
1919
+ @jax.jit
1920
+ def mean_variance_utility(returns: jnp.ndarray, beta: float) -> float:
1921
+ return jnp.mean(returns) - (beta / 2.0) * jnp.var(returns)
1922
+
1923
+
1924
+ @jax.jit
1925
+ def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
1926
+ alpha_mask = jax.lax.stop_gradient(
1927
+ returns <= jnp.percentile(returns, q=100 * alpha))
1928
+ return jnp.sum(returns * alpha_mask) / jnp.sum(alpha_mask)
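For intuition on the three utilities just added: applied to a batch of sampled returns, the entropic and mean-variance utilities sit below the plain mean by an amount that grows with beta, while CVaR averages only the worst alpha fraction of outcomes. The snippet below assumes these module-level helpers are importable from pyRDDLGym_jax.core.planner exactly as added here; the returns are synthetic.

import jax
import jax.numpy as jnp
from pyRDDLGym_jax.core.planner import (
    entropic_utility, mean_variance_utility, cvar_utility)

key = jax.random.PRNGKey(0)
returns = 10.0 + 2.0 * jax.random.normal(key, (1024,))   # simulated rollout returns

print(jnp.mean(returns))                         # risk-neutral baseline
print(entropic_utility(returns, beta=1.0))       # at most the mean, penalizes spread
print(mean_variance_utility(returns, beta=1.0))  # mean minus (beta / 2) * variance
print(cvar_utility(returns, alpha=0.05))         # average of the worst 5% of returns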
1929
+
1930
+
1931
+ # ***********************************************************************
1932
+ # ALL VERSIONS OF CONTROLLERS
1933
+ #
1934
+ # - offline controller is the straight-line planner
1935
+ # - online controller is the replanning mode
1936
+ #
1937
+ # ***********************************************************************
1938
+
1588
1939
  class JaxOfflineController(BaseAgent):
1589
1940
  '''A container class for a Jax policy trained offline.'''
1941
+
1590
1942
  use_tensor_obs = True
1591
1943
 
1592
- def __init__(self, planner: JaxBackpropPlanner, key: random.PRNGKey,
1593
- eval_hyperparams: Dict[str, object]=None,
1594
- params: Dict[str, object]=None,
1944
+ def __init__(self, planner: JaxBackpropPlanner,
1945
+ key: Optional[random.PRNGKey]=None,
1946
+ eval_hyperparams: Optional[Dict[str, Any]]=None,
1947
+ params: Optional[Pytree]=None,
1595
1948
  train_on_reset: bool=False,
1596
1949
  **train_kwargs) -> None:
1597
1950
  '''Creates a new JAX offline control policy that is trained once, then
1598
1951
  deployed later.
1599
1952
 
1600
1953
  :param planner: underlying planning algorithm for optimizing actions
1601
- :param key: the RNG key to seed randomness
1954
+ :param key: the RNG key to seed randomness (derives from clock if not
1955
+ provided)
1602
1956
  :param eval_hyperparams: policy hyperparameters to apply for evaluation
1603
1957
  or whenever sample_action is called
1604
1958
  :param params: use the specified policy parameters instead of calling
@@ -1608,6 +1962,8 @@ class JaxOfflineController(BaseAgent):
1608
1962
  for optimization
1609
1963
  '''
1610
1964
  self.planner = planner
1965
+ if key is None:
1966
+ key = random.PRNGKey(round(time.time() * 1000))
1611
1967
  self.key = key
1612
1968
  self.eval_hyperparams = eval_hyperparams
1613
1969
  self.train_on_reset = train_on_reset
@@ -1616,17 +1972,18 @@ class JaxOfflineController(BaseAgent):
1616
1972
 
1617
1973
  self.step = 0
1618
1974
  if not self.train_on_reset and not self.params_given:
1619
- params = self.planner.optimize(key=self.key, **self.train_kwargs)
1975
+ callback = self.planner.optimize(key=self.key, **self.train_kwargs)
1976
+ params = callback['best_params']
1620
1977
  self.params = params
1621
1978
 
1622
- def sample_action(self, state):
1979
+ def sample_action(self, state: Dict[str, Any]) -> Dict[str, Any]:
1623
1980
  self.key, subkey = random.split(self.key)
1624
1981
  actions = self.planner.get_action(
1625
1982
  subkey, self.params, self.step, state, self.eval_hyperparams)
1626
1983
  self.step += 1
1627
1984
  return actions
1628
1985
 
1629
- def reset(self):
1986
+ def reset(self) -> None:
1630
1987
  self.step = 0
1631
1988
  if self.train_on_reset and not self.params_given:
1632
1989
  self.params = self.planner.optimize(key=self.key, **self.train_kwargs)
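With this change the offline controller optimizes once (at construction, unless train_on_reset is set) and then simply replays the trained policy. A typical end-to-end use, following the project's README pattern; the domain and instance names, the epochs keyword and the evaluate signature are assumptions that may differ slightly in 0.2:

import pyRDDLGym
from pyRDDLGym_jax.core.planner import (
    JaxBackpropPlanner, JaxStraightLinePlan, JaxOfflineController)

env = pyRDDLGym.make('Cartpole_Continuous_gym', '0', vectorized=True)  # vectorized is required
planner = JaxBackpropPlanner(rddl=env.model, plan=JaxStraightLinePlan())
controller = JaxOfflineController(planner, epochs=1000)   # trains the plan here, once
controller.evaluate(env, episodes=1, verbose=True)
env.close()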
@@ -1635,41 +1992,51 @@ class JaxOfflineController(BaseAgent):
1635
1992
  class JaxOnlineController(BaseAgent):
1636
1993
  '''A container class for a Jax controller continuously updated using state
1637
1994
  feedback.'''
1995
+
1638
1996
  use_tensor_obs = True
1639
1997
 
1640
- def __init__(self, planner: JaxBackpropPlanner, key: random.PRNGKey,
1641
- eval_hyperparams: Dict=None, warm_start: bool=True,
1998
+ def __init__(self, planner: JaxBackpropPlanner,
1999
+ key: Optional[random.PRNGKey]=None,
2000
+ eval_hyperparams: Optional[Dict[str, Any]]=None,
2001
+ warm_start: bool=True,
1642
2002
  **train_kwargs) -> None:
1643
2003
  '''Creates a new JAX control policy that is trained online in a closed-
1644
2004
  loop fashion.
1645
2005
 
1646
2006
  :param planner: underlying planning algorithm for optimizing actions
1647
- :param key: the RNG key to seed randomness
2007
+ :param key: the RNG key to seed randomness (derives from clock if not
2008
+ provided)
1648
2009
  :param eval_hyperparams: policy hyperparameters to apply for evaluation
1649
2010
  or whenever sample_action is called
2011
+ :param warm_start: whether to use the previous decision epoch final
2012
+ policy parameters to warm the next decision epoch
1650
2013
  :param **train_kwargs: any keyword arguments to be passed to the planner
1651
2014
  for optimization
1652
2015
  '''
1653
2016
  self.planner = planner
2017
+ if key is None:
2018
+ key = random.PRNGKey(round(time.time() * 1000))
1654
2019
  self.key = key
1655
2020
  self.eval_hyperparams = eval_hyperparams
1656
2021
  self.warm_start = warm_start
1657
2022
  self.train_kwargs = train_kwargs
1658
2023
  self.reset()
1659
2024
 
1660
- def sample_action(self, state):
2025
+ def sample_action(self, state: Dict[str, Any]) -> Dict[str, Any]:
1661
2026
  planner = self.planner
1662
- params = planner.optimize(
2027
+ callback = planner.optimize(
1663
2028
  key=self.key,
1664
2029
  guess=self.guess,
1665
2030
  subs=state,
1666
2031
  **self.train_kwargs)
2032
+ params = callback['best_params']
1667
2033
  self.key, subkey = random.split(self.key)
1668
- actions = planner.get_action(subkey, params, 0, state, self.eval_hyperparams)
2034
+ actions = planner.get_action(
2035
+ subkey, params, 0, state, self.eval_hyperparams)
1669
2036
  if self.warm_start:
1670
2037
  self.guess = planner.plan.guess_next_epoch(params)
1671
2038
  return actions
1672
2039
 
1673
- def reset(self):
2040
+ def reset(self) -> None:
1674
2041
  self.guess = None
1675
2042
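Taken together, sample_action in the online controller amounts to the replanning loop below, shown unrolled. The pyRDDLGym.make call, env.horizon, the rollout_horizon keyword and the gymnasium-style reset/step signatures are assumptions about the surrounding libraries rather than part of this diff; the optimize/get_action/guess_next_epoch calls mirror the code above.

import pyRDDLGym
import jax.random as random
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxStraightLinePlan

env = pyRDDLGym.make('Cartpole_Continuous_gym', '0', vectorized=True)
planner = JaxBackpropPlanner(rddl=env.model, plan=JaxStraightLinePlan(),
                             rollout_horizon=5)             # short lookahead for replanning

key = random.PRNGKey(42)
guess = None
state, _ = env.reset()
for _ in range(env.horizon):
    callback = planner.optimize(key=key, guess=guess, subs=state)
    params = callback['best_params']                        # 0.2 returns a callback dict
    key, subkey = random.split(key)
    action = planner.get_action(subkey, params, 0, state)
    guess = planner.plan.guess_next_epoch(params)           # warm start for the next step
    state, reward, terminated, truncated, _ = env.step(action)
    if terminated or truncated:
        break
env.close()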