pyRDDLGym-jax 0.1-py3-none-any.whl → 0.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. pyRDDLGym_jax/__init__.py +1 -0
  2. pyRDDLGym_jax/core/compiler.py +444 -221
  3. pyRDDLGym_jax/core/logic.py +129 -62
  4. pyRDDLGym_jax/core/planner.py +965 -394
  5. pyRDDLGym_jax/core/simulator.py +5 -7
  6. pyRDDLGym_jax/core/tuning.py +29 -15
  7. pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_drp.cfg → Cartpole_Continuous_gym_drp.cfg} +2 -3
  8. pyRDDLGym_jax/examples/configs/{HVAC_drp.cfg → HVAC_ippc2023_drp.cfg} +4 -4
  9. pyRDDLGym_jax/examples/configs/{MarsRover_drp.cfg → MarsRover_ippc2023_drp.cfg} +1 -0
  10. pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +19 -0
  11. pyRDDLGym_jax/examples/configs/{Pendulum_slp.cfg → Pendulum_gym_slp.cfg} +1 -1
  12. pyRDDLGym_jax/examples/configs/{Pong_slp.cfg → Quadcopter_drp.cfg} +5 -5
  13. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +18 -0
  14. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +1 -1
  15. pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +1 -1
  16. pyRDDLGym_jax/examples/configs/default_drp.cfg +19 -0
  17. pyRDDLGym_jax/examples/configs/default_replan.cfg +20 -0
  18. pyRDDLGym_jax/examples/configs/default_slp.cfg +19 -0
  19. pyRDDLGym_jax/examples/run_gradient.py +1 -1
  20. pyRDDLGym_jax/examples/run_gym.py +3 -7
  21. pyRDDLGym_jax/examples/run_plan.py +10 -5
  22. pyRDDLGym_jax/examples/run_scipy.py +61 -0
  23. pyRDDLGym_jax/examples/run_tune.py +8 -3
  24. {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/METADATA +1 -1
  25. pyRDDLGym_jax-0.3.dist-info/RECORD +44 -0
  26. {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/WHEEL +1 -1
  27. pyRDDLGym_jax/examples/configs/SupplyChain_slp.cfg +0 -18
  28. pyRDDLGym_jax/examples/configs/Traffic_slp.cfg +0 -20
  29. pyRDDLGym_jax-0.1.dist-info/RECORD +0 -40
  30. /pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_replan.cfg → Cartpole_Continuous_gym_replan.cfg} +0 -0
  31. /pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_slp.cfg → Cartpole_Continuous_gym_slp.cfg} +0 -0
  32. /pyRDDLGym_jax/examples/configs/{HVAC_slp.cfg → HVAC_ippc2023_slp.cfg} +0 -0
  33. /pyRDDLGym_jax/examples/configs/{MarsRover_slp.cfg → MarsRover_ippc2023_slp.cfg} +0 -0
  34. /pyRDDLGym_jax/examples/configs/{MountainCar_slp.cfg → MountainCar_Continuous_gym_slp.cfg} +0 -0
  35. /pyRDDLGym_jax/examples/configs/{PowerGen_drp.cfg → PowerGen_Continuous_drp.cfg} +0 -0
  36. /pyRDDLGym_jax/examples/configs/{PowerGen_replan.cfg → PowerGen_Continuous_replan.cfg} +0 -0
  37. /pyRDDLGym_jax/examples/configs/{PowerGen_slp.cfg → PowerGen_Continuous_slp.cfg} +0 -0
  38. {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/LICENSE +0 -0
  39. {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,7 @@
  from ast import literal_eval
  from collections import deque
  import configparser
+ from enum import Enum
  import haiku as hk
  import jax
  import jax.numpy as jnp
@@ -12,12 +13,33 @@ import os
  import sys
  import termcolor
  import time
+ import traceback
  from tqdm import tqdm
- from typing import Callable, Dict, Generator, Set, Sequence, Tuple
+ from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Tuple, Union

+ Activation = Callable[[jnp.ndarray], jnp.ndarray]
+ Bounds = Dict[str, Tuple[np.ndarray, np.ndarray]]
+ Kwargs = Dict[str, Any]
+ Pytree = Any
+
+ from pyRDDLGym.core.debug.exception import raise_warning
+
+ from pyRDDLGym_jax import __version__
+
+ # try to import matplotlib, if failed then skip plotting
+ try:
+ import matplotlib
+ import matplotlib.pyplot as plt
+ matplotlib.use('TkAgg')
+ except Exception:
+ raise_warning('failed to import matplotlib: '
+ 'plotting functionality will be disabled.', 'red')
+ traceback.print_exc()
+ plt = None
+
  from pyRDDLGym.core.compiler.model import RDDLPlanningModel, RDDLLiftedModel
+ from pyRDDLGym.core.debug.logger import Logger
  from pyRDDLGym.core.debug.exception import (
- raise_warning,
  RDDLNotImplementedError,
  RDDLUndefinedVariableError,
  RDDLTypeError
@@ -37,6 +59,7 @@ from pyRDDLGym_jax.core.logic import FuzzyLogic
  # - instantiate planner
  #
  # ***********************************************************************
+
  def _parse_config_file(path: str):
  if not os.path.isfile(path):
  raise FileNotFoundError(f'File {path} does not exist.')
@@ -59,51 +82,96 @@ def _parse_config_string(value: str):
  return config, args


+ def _getattr_any(packages, item):
+ for package in packages:
+ loaded = getattr(package, item, None)
+ if loaded is not None:
+ return loaded
+ return None
+
+
  def _load_config(config, args):
  model_args = {k: args[k] for (k, _) in config.items('Model')}
  planner_args = {k: args[k] for (k, _) in config.items('Optimizer')}
  train_args = {k: args[k] for (k, _) in config.items('Training')}

- train_args['key'] = jax.random.PRNGKey(train_args['key'])
-
  # read the model settings
- tnorm_name = model_args['tnorm']
- tnorm_kwargs = model_args['tnorm_kwargs']
- logic_name = model_args['logic']
- logic_kwargs = model_args['logic_kwargs']
+ logic_name = model_args.get('logic', 'FuzzyLogic')
+ logic_kwargs = model_args.get('logic_kwargs', {})
+ tnorm_name = model_args.get('tnorm', 'ProductTNorm')
+ tnorm_kwargs = model_args.get('tnorm_kwargs', {})
+ comp_name = model_args.get('complement', 'StandardComplement')
+ comp_kwargs = model_args.get('complement_kwargs', {})
+ compare_name = model_args.get('comparison', 'SigmoidComparison')
+ compare_kwargs = model_args.get('comparison_kwargs', {})
  logic_kwargs['tnorm'] = getattr(logic, tnorm_name)(**tnorm_kwargs)
- planner_args['logic'] = getattr(logic, logic_name)(**logic_kwargs)
+ logic_kwargs['complement'] = getattr(logic, comp_name)(**comp_kwargs)
+ logic_kwargs['comparison'] = getattr(logic, compare_name)(**compare_kwargs)

- # read the optimizer settings
+ # read the policy settings
  plan_method = planner_args.pop('method')
  plan_kwargs = planner_args.pop('method_kwargs', {})

- if 'initializer' in plan_kwargs: # weight initialization
- init_name = plan_kwargs['initializer']
- init_class = getattr(initializers, init_name)
- init_kwargs = plan_kwargs.pop('initializer_kwargs', {})
- try:
- plan_kwargs['initializer'] = init_class(**init_kwargs)
- except:
- raise_warning(f'ignoring arguments for initializer <{init_name}>')
- plan_kwargs['initializer'] = init_class
-
- if 'activation' in plan_kwargs: # activation function
- plan_kwargs['activation'] = getattr(jax.nn, plan_kwargs['activation'])
+ # policy initialization
+ plan_initializer = plan_kwargs.get('initializer', None)
+ if plan_initializer is not None:
+ initializer = _getattr_any(
+ packages=[initializers, hk.initializers], item=plan_initializer)
+ if initializer is None:
+ raise_warning(
+ f'Ignoring invalid initializer <{plan_initializer}>.', 'red')
+ del plan_kwargs['initializer']
+ else:
+ init_kwargs = plan_kwargs.pop('initializer_kwargs', {})
+ try:
+ plan_kwargs['initializer'] = initializer(**init_kwargs)
+ except Exception as _:
+ raise_warning(
+ f'Ignoring invalid initializer_kwargs <{init_kwargs}>.', 'red')
+ plan_kwargs['initializer'] = initializer
+
+ # policy activation
+ plan_activation = plan_kwargs.get('activation', None)
+ if plan_activation is not None:
+ activation = _getattr_any(
+ packages=[jax.nn, jax.numpy], item=plan_activation)
+ if activation is None:
+ raise_warning(
+ f'Ignoring invalid activation <{plan_activation}>.', 'red')
+ del plan_kwargs['activation']
+ else:
+ plan_kwargs['activation'] = activation

+ # read the planner settings
+ planner_args['logic'] = getattr(logic, logic_name)(**logic_kwargs)
  planner_args['plan'] = getattr(sys.modules[__name__], plan_method)(**plan_kwargs)
- planner_args['optimizer'] = getattr(optax, planner_args['optimizer'])
+
+ # planner optimizer
+ planner_optimizer = planner_args.get('optimizer', None)
+ if planner_optimizer is not None:
+ optimizer = _getattr_any(packages=[optax], item=planner_optimizer)
+ if optimizer is None:
+ raise_warning(
+ f'Ignoring invalid optimizer <{planner_optimizer}>.', 'red')
+ del planner_args['optimizer']
+ else:
+ planner_args['optimizer'] = optimizer
+
+ # read the optimize call settings
+ planner_key = train_args.get('key', None)
+ if planner_key is not None:
+ train_args['key'] = random.PRNGKey(planner_key)

  return planner_args, plan_kwargs, train_args


- def load_config(path: str) -> Tuple[Dict[str, object], ...]:
+ def load_config(path: str) -> Tuple[Kwargs, ...]:
  '''Loads a config file at the specified file path.'''
  config, args = _parse_config_file(path)
  return _load_config(config, args)


- def load_config_from_string(value: str) -> Tuple[Dict[str, object], ...]:
+ def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
  '''Loads config file contents specified explicitly as a string value.'''
  config, args = _parse_config_string(value)
  return _load_config(config, args)
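Note: the rewritten _load_config above now falls back to defaults for every [Model] key and additionally reads complement and comparison entries. A minimal sketch of a config passed through load_config_from_string follows; the section and key names are the ones read in this hunk, while the concrete values are illustrative assumptions.

from pyRDDLGym_jax.core.planner import load_config_from_string

# Sketch only: values are illustrative; accepted names come from
# pyRDDLGym_jax.core.logic (logic/tnorm/complement/comparison) and optax (optimizer).
CONFIG = """
[Model]
logic='FuzzyLogic'
tnorm='ProductTNorm'
complement='StandardComplement'
comparison='SigmoidComparison'

[Optimizer]
method='JaxStraightLinePlan'
method_kwargs={}
optimizer='rmsprop'
optimizer_kwargs={'learning_rate': 0.1}

[Training]
key=42
"""
planner_args, plan_kwargs, train_args = load_config_from_string(CONFIG)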
@@ -115,6 +183,20 @@ def load_config_from_string(value: str) -> Tuple[Dict[str, object], ...]:
  # - replace discrete ops in state dynamics/reward with differentiable ones
  #
  # ***********************************************************************
+
+ def _function_discrete_approx_named(logic):
+ jax_discrete, jax_param = logic.discrete()
+
+ def _jax_wrapped_discrete_calc_approx(key, prob, params):
+ sample = jax_discrete(key, prob, params)
+ out_of_bounds = jnp.logical_not(jnp.logical_and(
+ jnp.all(prob >= 0),
+ jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
+ return sample, out_of_bounds
+
+ return _jax_wrapped_discrete_calc_approx, jax_param
+
+
  class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
  '''Compiles a RDDL AST representation to an equivalent JAX representation.
  Unlike its parent class, this class treats all fluents as real-valued, and
@@ -124,7 +206,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):

  def __init__(self, *args,
  logic: FuzzyLogic=FuzzyLogic(),
- cpfs_without_grad: Set=set(),
+ cpfs_without_grad: Optional[Set[str]]=None,
  **kwargs) -> None:
  '''Creates a new RDDL to Jax compiler, where operations that are not
  differentiable are converted to approximate forms that have defined
@@ -139,28 +221,37 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
  :param *kwargs: keyword arguments to pass to base compiler
  '''
  super(JaxRDDLCompilerWithGrad, self).__init__(*args, **kwargs)
+
  self.logic = logic
+ self.logic.set_use64bit(self.use64bit)
+ if cpfs_without_grad is None:
+ cpfs_without_grad = set()
  self.cpfs_without_grad = cpfs_without_grad

  # actions and CPFs must be continuous
- raise_warning(f'Initial values of pvariables will be cast to real.')
+ pvars_cast = set()
  for (var, values) in self.init_values.items():
  self.init_values[var] = np.asarray(values, dtype=self.REAL)
+ if not np.issubdtype(np.atleast_1d(values).dtype, np.floating):
+ pvars_cast.add(var)
+ if pvars_cast:
+ raise_warning(f'JAX gradient compiler requires that initial values '
+ f'of p-variables {pvars_cast} be cast to float.')

  # overwrite basic operations with fuzzy ones
  self.RELATIONAL_OPS = {
- '>=': logic.greaterEqual(),
- '<=': logic.lessEqual(),
+ '>=': logic.greater_equal(),
+ '<=': logic.less_equal(),
  '<': logic.less(),
  '>': logic.greater(),
  '==': logic.equal(),
- '~=': logic.notEqual()
+ '~=': logic.not_equal()
  }
- self.LOGICAL_NOT = logic.Not()
+ self.LOGICAL_NOT = logic.logical_not()
  self.LOGICAL_OPS = {
- '^': logic.And(),
- '&': logic.And(),
- '|': logic.Or(),
+ '^': logic.logical_and(),
+ '&': logic.logical_and(),
+ '|': logic.logical_or(),
  '~': logic.xor(),
  '=>': logic.implies(),
  '<=>': logic.equiv()
@@ -169,15 +260,19 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
  self.AGGREGATION_OPS['exists'] = logic.exists()
  self.AGGREGATION_OPS['argmin'] = logic.argmin()
  self.AGGREGATION_OPS['argmax'] = logic.argmax()
- self.KNOWN_UNARY['sgn'] = logic.signum()
+ self.KNOWN_UNARY['sgn'] = logic.sgn()
  self.KNOWN_UNARY['floor'] = logic.floor()
  self.KNOWN_UNARY['ceil'] = logic.ceil()
  self.KNOWN_UNARY['round'] = logic.round()
  self.KNOWN_UNARY['sqrt'] = logic.sqrt()
- self.KNOWN_BINARY['div'] = logic.floorDiv()
+ self.KNOWN_BINARY['div'] = logic.div()
  self.KNOWN_BINARY['mod'] = logic.mod()
  self.KNOWN_BINARY['fmod'] = logic.mod()
-
+ self.IF_HELPER = logic.control_if()
+ self.SWITCH_HELPER = logic.control_switch()
+ self.BERNOULLI_HELPER = logic.bernoulli()
+ self.DISCRETE_HELPER = _function_discrete_approx_named(logic)
+
  def _jax_stop_grad(self, jax_expr):

  def _jax_wrapped_stop_grad(x, params, key):
@@ -188,46 +283,33 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
  return _jax_wrapped_stop_grad

  def _compile_cpfs(self, info):
- raise_warning('CPFs outputs will be cast to real.')
+ cpfs_cast = set()
  jax_cpfs = {}
  for (_, cpfs) in self.levels.items():
  for cpf in cpfs:
  _, expr = self.rddl.cpfs[cpf]
  jax_cpfs[cpf] = self._jax(expr, info, dtype=self.REAL)
+ if self.rddl.variable_ranges[cpf] != 'real':
+ cpfs_cast.add(cpf)
  if cpf in self.cpfs_without_grad:
- raise_warning(f'CPF <{cpf}> stops gradient.')
  jax_cpfs[cpf] = self._jax_stop_grad(jax_cpfs[cpf])
+
+ if cpfs_cast:
+ raise_warning(f'JAX gradient compiler requires that outputs of CPFs '
+ f'{cpfs_cast} be cast to float.')
+ if self.cpfs_without_grad:
+ raise_warning(f'User requested that gradients not flow '
+ f'through CPFs {self.cpfs_without_grad}.')
  return jax_cpfs

- def _jax_if_helper(self):
- return self.logic.If()
-
- def _jax_switch_helper(self):
- return self.logic.Switch()
-
  def _jax_kron(self, expr, info):
  if self.logic.verbose:
- raise_warning('KronDelta will be ignored.')
-
+ raise_warning('JAX gradient compiler ignores KronDelta '
+ 'during compilation.')
  arg, = expr.args
  arg = self._jax(arg, info)
  return arg

- def _jax_bernoulli_helper(self):
- return self.logic.bernoulli()
-
- def _jax_discrete_helper(self):
- jax_discrete, jax_param = self.logic.discrete()
-
- def _jax_wrapped_discrete_calc_approx(key, prob, params):
- sample = jax_discrete(key, prob, params)
- out_of_bounds = jnp.logical_not(jnp.logical_and(
- jnp.all(prob >= 0),
- jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
- return sample, out_of_bounds
-
- return _jax_wrapped_discrete_calc_approx, jax_param
-

  # ***********************************************************************
  # ALL VERSIONS OF JAX PLANS
@@ -236,6 +318,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
  # - deep reactive policy
  #
  # ***********************************************************************
+
  class JaxPlan:
  '''Base class for all JAX policy representations.'''

@@ -244,16 +327,17 @@ class JaxPlan:
  self._train_policy = None
  self._test_policy = None
  self._projection = None
-
- def summarize_hyperparameters(self):
+ self.bounds = None
+
+ def summarize_hyperparameters(self) -> None:
  pass

  def compile(self, compiled: JaxRDDLCompilerWithGrad,
- _bounds: Dict,
+ _bounds: Bounds,
  horizon: int) -> None:
  raise NotImplementedError

- def guess_next_epoch(self, params: Dict) -> Dict:
+ def guess_next_epoch(self, params: Pytree) -> Pytree:
  raise NotImplementedError

  @property
@@ -289,7 +373,8 @@ class JaxPlan:
  self._projection = value

  def _calculate_action_info(self, compiled: JaxRDDLCompilerWithGrad,
- user_bounds: Dict[str, object], horizon: int):
+ user_bounds: Bounds,
+ horizon: int):
  shapes, bounds, bounds_safe, cond_lists = {}, {}, {}, {}
  for (name, prange) in compiled.rddl.variable_ranges.items():
  if compiled.rddl.variable_types[name] != 'action-fluent':
@@ -298,7 +383,7 @@ class JaxPlan:
  # check invalid type
  if prange not in compiled.JAX_TYPES:
  raise RDDLTypeError(
- f'Invalid range <{prange}. of action-fluent <{name}>, '
+ f'Invalid range <{prange}> of action-fluent <{name}>, '
  f'must be one of {set(compiled.JAX_TYPES.keys())}.')

  # clip boolean to (0, 1), otherwise use the RDDL action bounds
@@ -309,8 +394,8 @@ class JaxPlan:
  else:
  lower, upper = compiled.constraints.bounds[name]
  lower, upper = user_bounds.get(name, (lower, upper))
- lower = np.asarray(lower, dtype=np.float32)
- upper = np.asarray(upper, dtype=np.float32)
+ lower = np.asarray(lower, dtype=compiled.REAL)
+ upper = np.asarray(upper, dtype=compiled.REAL)
  lower_finite = np.isfinite(lower)
  upper_finite = np.isfinite(upper)
  bounds_safe[name] = (np.where(lower_finite, lower, 0.0),
@@ -320,7 +405,7 @@ class JaxPlan:
  ~lower_finite & upper_finite,
  ~lower_finite & ~upper_finite]
  bounds[name] = (lower, upper)
- raise_warning(f'Bounds of action fluent <{name}> set to {bounds[name]}.')
+ raise_warning(f'Bounds of action-fluent <{name}> set to {bounds[name]}.')
  return shapes, bounds, bounds_safe, cond_lists

  def _count_bool_actions(self, rddl: RDDLLiftedModel):
@@ -336,7 +421,7 @@ class JaxStraightLinePlan(JaxPlan):

  def __init__(self, initializer: initializers.Initializer=initializers.normal(),
  wrap_sigmoid: bool=True,
- min_action_prob: float=1e-5,
+ min_action_prob: float=1e-6,
  wrap_non_bool: bool=False,
  wrap_softmax: bool=False,
  use_new_projection: bool=False,
@@ -362,6 +447,7 @@ class JaxStraightLinePlan(JaxPlan):
  use_new_projection = True
  '''
  super(JaxStraightLinePlan, self).__init__()
+
  self._initializer_base = initializer
  self._initializer = initializer
  self._wrap_sigmoid = wrap_sigmoid
@@ -371,10 +457,13 @@ class JaxStraightLinePlan(JaxPlan):
  self._use_new_projection = use_new_projection
  self._max_constraint_iter = max_constraint_iter

- def summarize_hyperparameters(self):
+ def summarize_hyperparameters(self) -> None:
+ bounds = '\n '.join(
+ map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
  print(f'policy hyper-parameters:\n'
- f' initializer ={type(self._initializer_base).__name__}\n'
+ f' initializer ={self._initializer_base}\n'
  f'constraint-sat strategy (simple):\n'
+ f' parsed_action_bounds =\n {bounds}\n'
  f' wrap_sigmoid ={self._wrap_sigmoid}\n'
  f' wrap_sigmoid_min_prob={self._min_action_prob}\n'
  f' wrap_non_bool ={self._wrap_non_bool}\n'
@@ -383,7 +472,8 @@ class JaxStraightLinePlan(JaxPlan):
  f' use_new_projection ={self._use_new_projection}')

  def compile(self, compiled: JaxRDDLCompilerWithGrad,
- _bounds: Dict, horizon: int) -> None:
+ _bounds: Bounds,
+ horizon: int) -> None:
  rddl = compiled.rddl

  # calculate the correct action box bounds
@@ -423,7 +513,7 @@ class JaxStraightLinePlan(JaxPlan):
  def _jax_bool_action_to_param(var, action, hyperparams):
  if wrap_sigmoid:
  weight = hyperparams[var]
- return (-1.0 / weight) * jnp.log1p(1.0 / action - 2.0)
+ return (-1.0 / weight) * jnp.log(1.0 / action - 1.0)
  else:
  return action

@@ -506,7 +596,7 @@ class JaxStraightLinePlan(JaxPlan):
  def _jax_wrapped_slp_predict_test(key, params, hyperparams, step, subs):
  actions = {}
  for (var, param) in params.items():
- action = jnp.asarray(param[step, ...])
+ action = jnp.asarray(param[step, ...], dtype=compiled.REAL)
  if var == bool_key:
  output = jax.nn.softmax(action)
  bool_actions = _jax_unstack_bool_from_softmax(output)
@@ -537,7 +627,7 @@ class JaxStraightLinePlan(JaxPlan):
  if 1 < allowed_actions < bool_action_count:
  raise RDDLNotImplementedError(
  f'Straight-line plans with wrap_softmax currently '
- f'do not support max-nondef-actions = {allowed_actions} > 1.')
+ f'do not support max-nondef-actions {allowed_actions} > 1.')

  # potentially apply projection but to non-bool actions only
  self.projection = _jax_wrapped_slp_project_to_box
@@ -668,14 +758,14 @@ class JaxStraightLinePlan(JaxPlan):
  for (var, shape) in shapes.items():
  if ranges[var] != 'bool' or not stack_bool_params:
  key, subkey = random.split(key)
- param = init(subkey, shape, dtype=compiled.REAL)
+ param = init(key=subkey, shape=shape, dtype=compiled.REAL)
  if ranges[var] == 'bool':
  param += bool_threshold
  params[var] = param
  if stack_bool_params:
  key, subkey = random.split(key)
  bool_shape = (horizon, bool_action_count)
- bool_param = init(subkey, bool_shape, dtype=compiled.REAL)
+ bool_param = init(key=subkey, shape=bool_shape, dtype=compiled.REAL)
  params[bool_key] = bool_param
  params, _ = _jax_wrapped_slp_project_to_box(params, hyperparams)
  return params
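For reference, a constructor sketch for the straight-line plan whose defaults change above (all keyword names appear in the diffed constructor; the chosen values are illustrative, not recommendations):

from pyRDDLGym_jax.core.planner import JaxStraightLinePlan

plan = JaxStraightLinePlan(
    wrap_sigmoid=True,        # parameterize bool actions through a sigmoid
    min_action_prob=1e-6,     # new default shown above (was 1e-5)
    wrap_non_bool=False,
    wrap_softmax=False,
    use_new_projection=False)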
@@ -688,7 +778,7 @@ class JaxStraightLinePlan(JaxPlan):
  # "progress" the plan one step forward and set last action to second-last
  return jnp.append(param[1:, ...], param[-1:, ...], axis=0)

- def guess_next_epoch(self, params: Dict) -> Dict:
+ def guess_next_epoch(self, params: Pytree) -> Pytree:
  next_fn = JaxStraightLinePlan._guess_next_epoch
  return jax.tree_map(next_fn, params)

@@ -696,10 +786,13 @@ class JaxDeepReactivePolicy(JaxPlan):
  class JaxDeepReactivePolicy(JaxPlan):
  '''A deep reactive policy network implementation in JAX.'''

- def __init__(self, topology: Sequence[int],
- activation: Callable=jax.nn.relu,
+ def __init__(self, topology: Optional[Sequence[int]]=None,
+ activation: Activation=jnp.tanh,
  initializer: hk.initializers.Initializer=hk.initializers.VarianceScaling(scale=2.0),
- normalize: bool=True) -> None:
+ normalize: bool=False,
+ normalize_per_layer: bool=False,
+ normalizer_kwargs: Optional[Kwargs]=None,
+ wrap_non_bool: bool=False) -> None:
  '''Creates a new deep reactive policy in JAX.

  :param neurons: sequence consisting of the number of neurons in each
@@ -707,23 +800,45 @@ class JaxDeepReactivePolicy(JaxPlan):
  :param activation: function to apply after each layer of the policy
  :param initializer: weight initialization
  :param normalize: whether to apply layer norm to the inputs
+ :param normalize_per_layer: whether to apply layer norm to each input
+ individually (only active if normalize is True)
+ :param normalizer_kwargs: if normalize is True, apply additional arguments
+ to layer norm
+ :param wrap_non_bool: whether to wrap real or int action fluent parameters
+ with non-linearity (e.g. sigmoid or ELU) to satisfy box constraints
  '''
  super(JaxDeepReactivePolicy, self).__init__()
+
+ if topology is None:
+ topology = [128, 64]
  self._topology = topology
  self._activations = [activation for _ in topology]
  self._initializer_base = initializer
  self._initializer = initializer
  self._normalize = normalize
+ self._normalize_per_layer = normalize_per_layer
+ if normalizer_kwargs is None:
+ normalizer_kwargs = {'create_offset': True, 'create_scale': True}
+ self._normalizer_kwargs = normalizer_kwargs
+ self._wrap_non_bool = wrap_non_bool

- def summarize_hyperparameters(self):
+ def summarize_hyperparameters(self) -> None:
+ bounds = '\n '.join(
+ map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
  print(f'policy hyper-parameters:\n'
- f' topology ={self._topology}\n'
- f' activation_fn ={self._activations[0].__name__}\n'
- f' initializer ={type(self._initializer_base).__name__}\n'
- f' apply_layer_norm={self._normalize}')
+ f' topology ={self._topology}\n'
+ f' activation_fn ={self._activations[0].__name__}\n'
+ f' initializer ={type(self._initializer_base).__name__}\n'
+ f' apply_input_norm ={self._normalize}\n'
+ f' input_norm_layerwise={self._normalize_per_layer}\n'
+ f' input_norm_args ={self._normalizer_kwargs}\n'
+ f'constraint-sat strategy:\n'
+ f' parsed_action_bounds=\n {bounds}\n'
+ f' wrap_non_bool ={self._wrap_non_bool}')

  def compile(self, compiled: JaxRDDLCompilerWithGrad,
- _bounds: Dict, horizon: int) -> None:
+ _bounds: Bounds,
+ horizon: int) -> None:
  rddl = compiled.rddl

  # calculate the correct action box bounds
@@ -737,7 +852,7 @@ class JaxDeepReactivePolicy(JaxPlan):
  if 1 < allowed_actions < bool_action_count:
  raise RDDLNotImplementedError(
  f'Deep reactive policies currently do not support '
- f'max-nondef-actions = {allowed_actions} > 1.')
+ f'max-nondef-actions {allowed_actions} > 1.')
  use_constraint_satisfaction = allowed_actions < bool_action_count

  noop = {var: (values[0] if isinstance(values, list) else values)
@@ -751,22 +866,75 @@ class JaxDeepReactivePolicy(JaxPlan):

  ranges = rddl.variable_ranges
  normalize = self._normalize
+ normalize_per_layer = self._normalize_per_layer
+ wrap_non_bool = self._wrap_non_bool
  init = self._initializer
  layers = list(enumerate(zip(self._topology, self._activations)))
  layer_sizes = {var: np.prod(shape, dtype=int)
  for (var, shape) in shapes.items()}
  layer_names = {var: f'output_{var}'.replace('-', '_') for var in shapes}

- # predict actions from the policy network for current state
- def _jax_wrapped_policy_network_predict(state):
+ # inputs for the policy network
+ if rddl.observ_fluents:
+ observed_vars = rddl.observ_fluents
+ else:
+ observed_vars = rddl.state_fluents
+ input_names = {var: f'{var}'.replace('-', '_') for var in observed_vars}
+
+ # catch if input norm is applied to size 1 tensor
+ if normalize:
+ non_bool_dims = 0
+ for (var, values) in observed_vars.items():
+ if ranges[var] != 'bool':
+ value_size = np.atleast_1d(values).size
+ if normalize_per_layer and value_size == 1:
+ raise_warning(
+ f'Cannot apply layer norm to state-fluent <{var}> '
+ f'of size 1: setting normalize_per_layer = False.',
+ 'red')
+ normalize_per_layer = False
+ non_bool_dims += value_size
+ if not normalize_per_layer and non_bool_dims == 1:
+ raise_warning(
+ 'Cannot apply layer norm to state-fluents of total size 1: '
+ 'setting normalize = False.', 'red')
+ normalize = False
+
+ # convert subs dictionary into a state vector to feed to the MLP
+ def _jax_wrapped_policy_input(subs):

- # apply layer norm
- if normalize:
+ # concatenate all state variables into a single vector
+ # optionally apply layer norm to each input tensor
+ states_bool, states_non_bool = [], []
+ non_bool_dims = 0
+ for (var, value) in subs.items():
+ if var in observed_vars:
+ state = jnp.ravel(value)
+ if ranges[var] == 'bool':
+ states_bool.append(state)
+ else:
+ if normalize and normalize_per_layer:
+ normalizer = hk.LayerNorm(
+ axis=-1, param_axis=-1,
+ name=f'input_norm_{input_names[var]}',
+ **self._normalizer_kwargs)
+ state = normalizer(state)
+ states_non_bool.append(state)
+ non_bool_dims += state.size
+ state = jnp.concatenate(states_non_bool + states_bool)
+
+ # optionally perform layer normalization on the non-bool inputs
+ if normalize and not normalize_per_layer and non_bool_dims:
  normalizer = hk.LayerNorm(
- axis=-1, param_axis=-1,
- create_offset=True, create_scale=True,
- name='input_norm')
- state = normalizer(state)
+ axis=-1, param_axis=-1, name='input_norm',
+ **self._normalizer_kwargs)
+ normalized = normalizer(state[:non_bool_dims])
+ state = state.at[:non_bool_dims].set(normalized)
+ return state
+
+ # predict actions from the policy network for current state
+ def _jax_wrapped_policy_network_predict(subs):
+ state = _jax_wrapped_policy_input(subs)

  # feed state vector through hidden layers
  hidden = state
@@ -789,16 +957,19 @@ class JaxDeepReactivePolicy(JaxPlan):
  if not use_constraint_satisfaction:
  actions[var] = jax.nn.sigmoid(output)
  else:
- lower, upper = bounds_safe[var]
- action = jnp.select(
- condlist=cond_lists[var],
- choicelist=[
- lower + (upper - lower) * jax.nn.sigmoid(output),
- lower + (jax.nn.elu(output) + 1.0),
- upper - (jax.nn.elu(-output) + 1.0),
- output
- ]
- )
+ if wrap_non_bool:
+ lower, upper = bounds_safe[var]
+ action = jnp.select(
+ condlist=cond_lists[var],
+ choicelist=[
+ lower + (upper - lower) * jax.nn.sigmoid(output),
+ lower + (jax.nn.elu(output) + 1.0),
+ upper - (jax.nn.elu(-output) + 1.0),
+ output
+ ]
+ )
+ else:
+ action = output
  actions[var] = action

  # for constraint satisfaction wrap bool actions with softmax
@@ -826,21 +997,14 @@ class JaxDeepReactivePolicy(JaxPlan):
  actions[name] = action
  start += size
  return actions
-
- # state is concatenated into single tensor
- def _jax_wrapped_subs_to_state(subs):
- subs = {var: value
- for (var, value) in subs.items()
- if var in rddl.state_fluents}
- flat_subs = jax.tree_map(jnp.ravel, subs)
- states = list(flat_subs.values())
- state = jnp.concatenate(states)
- return state

  # train action prediction
  def _jax_wrapped_drp_predict_train(key, params, hyperparams, step, subs):
- state = _jax_wrapped_subs_to_state(subs)
- actions = predict_fn.apply(params, state)
+ actions = predict_fn.apply(params, subs)
+ if not wrap_non_bool:
+ for (var, action) in actions.items():
+ if var != bool_key and ranges[var] != 'bool':
+ actions[var] = jnp.clip(action, *bounds[var])
  if use_constraint_satisfaction:
  bool_actions = _jax_unstack_bool_from_softmax(actions[bool_key])
  actions.update(bool_actions)
@@ -886,14 +1050,13 @@ class JaxDeepReactivePolicy(JaxPlan):
  def _jax_wrapped_drp_init(key, hyperparams, subs):
  subs = {var: value[0, ...]
  for (var, value) in subs.items()
- if var in rddl.state_fluents}
- state = _jax_wrapped_subs_to_state(subs)
- params = predict_fn.init(key, state)
+ if var in observed_vars}
+ params = predict_fn.init(key, subs)
  return params

  self.initializer = _jax_wrapped_drp_init

- def guess_next_epoch(self, params: Dict) -> Dict:
+ def guess_next_epoch(self, params: Pytree) -> Pytree:
  return params
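For the reactive policy above, topology is now optional (defaulting to [128, 64]), inputs come from observ-fluents when the domain is partially observed, and input normalization and non-bool wrapping are opt-in. A constructor sketch using only the keyword names shown in these hunks (the sizes and flags are illustrative):

from pyRDDLGym_jax.core.planner import JaxDeepReactivePolicy

policy = JaxDeepReactivePolicy(
    topology=[64, 32],            # illustrative; defaults to [128, 64] if omitted
    normalize=True,               # layer-norm the non-bool inputs
    normalize_per_layer=False,
    normalizer_kwargs={'create_scale': True, 'create_offset': True},
    wrap_non_bool=True)           # squash real/int outputs into their box bounds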


@@ -904,24 +1067,170 @@ class JaxDeepReactivePolicy(JaxPlan):
  # - more stable but slower line search based planner
  #
  # ***********************************************************************
+
+ class RollingMean:
+ '''Maintains an estimate of the rolling mean of a stream of real-valued
+ observations.'''
+
+ def __init__(self, window_size: int) -> None:
+ self._window_size = window_size
+ self._memory = deque(maxlen=window_size)
+ self._total = 0
+
+ def update(self, x: float) -> float:
+ memory = self._memory
+ self._total += x
+ if len(memory) == self._window_size:
+ self._total -= memory.popleft()
+ memory.append(x)
+ return self._total / len(memory)
+
+
+ class JaxPlannerPlot:
+ '''Supports plotting and visualization of a JAX policy in real time.'''
+
+ def __init__(self, rddl: RDDLPlanningModel, horizon: int,
+ show_violin: bool=True, show_action: bool=True) -> None:
+ '''Creates a new planner visualizer.
+
+ :param rddl: the planning model to optimize
+ :param horizon: the lookahead or planning horizon
+ :param show_violin: whether to show the distribution of batch losses
+ :param show_action: whether to show heatmaps of the action fluents
+ '''
+ num_plots = 1
+ if show_violin:
+ num_plots += 1
+ if show_action:
+ num_plots += len(rddl.action_fluents)
+ self._fig, axes = plt.subplots(num_plots)
+ if num_plots == 1:
+ axes = [axes]
+
+ # prepare the loss plot
+ self._loss_ax = axes[0]
+ self._loss_ax.autoscale(enable=True)
+ self._loss_ax.set_xlabel('training time')
+ self._loss_ax.set_ylabel('loss value')
+ self._loss_plot = self._loss_ax.plot(
+ [], [], linestyle=':', marker='o', markersize=2)[0]
+ self._loss_back = self._fig.canvas.copy_from_bbox(self._loss_ax.bbox)
+
+ # prepare the violin plot
+ if show_violin:
+ self._hist_ax = axes[1]
+ else:
+ self._hist_ax = None
+
+ # prepare the action plots
+ if show_action:
+ self._action_ax = {name: axes[idx + (2 if show_violin else 1)]
+ for (idx, name) in enumerate(rddl.action_fluents)}
+ self._action_plots = {}
+ for name in rddl.action_fluents:
+ ax = self._action_ax[name]
+ if rddl.variable_ranges[name] == 'bool':
+ vmin, vmax = 0.0, 1.0
+ else:
+ vmin, vmax = None, None
+ action_dim = 1
+ for dim in rddl.object_counts(rddl.variable_params[name]):
+ action_dim *= dim
+ action_plot = ax.pcolormesh(
+ np.zeros((action_dim, horizon)),
+ cmap='seismic', vmin=vmin, vmax=vmax)
+ ax.set_aspect('auto')
+ ax.set_xlabel('decision epoch')
+ ax.set_ylabel(name)
+ plt.colorbar(action_plot, ax=ax)
+ self._action_plots[name] = action_plot
+ self._action_back = {name: self._fig.canvas.copy_from_bbox(ax.bbox)
+ for (name, ax) in self._action_ax.items()}
+ else:
+ self._action_ax = None
+ self._action_plots = None
+ self._action_back = None
+
+ plt.tight_layout()
+ plt.show(block=False)
+
+ def redraw(self, xticks, losses, actions, returns) -> None:
+
+ # draw the loss curve
+ self._fig.canvas.restore_region(self._loss_back)
+ self._loss_plot.set_xdata(xticks)
+ self._loss_plot.set_ydata(losses)
+ self._loss_ax.set_xlim([0, len(xticks)])
+ self._loss_ax.set_ylim([np.min(losses), np.max(losses)])
+ self._loss_ax.draw_artist(self._loss_plot)
+ self._fig.canvas.blit(self._loss_ax.bbox)
+
+ # draw the violin plot
+ if self._hist_ax is not None:
+ self._hist_ax.clear()
+ self._hist_ax.set_xlabel('loss value')
+ self._hist_ax.set_ylabel('density')
+ self._hist_ax.violinplot(returns, vert=False, showmeans=True)
+
+ # draw the actions
+ if self._action_ax is not None:
+ for (name, values) in actions.items():
+ values = np.mean(values, axis=0, dtype=float)
+ values = np.reshape(values, newshape=(values.shape[0], -1)).T
+ self._fig.canvas.restore_region(self._action_back[name])
+ self._action_plots[name].set_array(values)
+ self._action_ax[name].draw_artist(self._action_plots[name])
+ self._fig.canvas.blit(self._action_ax[name].bbox)
+ self._action_plots[name].set_clim([np.min(values), np.max(values)])
+
+ self._fig.canvas.draw()
+ self._fig.canvas.flush_events()
+
+ def close(self) -> None:
+ plt.close(self._fig)
+ del self._loss_ax, self._hist_ax, self._action_ax, \
+ self._loss_plot, self._action_plots, self._fig, \
+ self._loss_back, self._action_back
+
+
+ class JaxPlannerStatus(Enum):
+ '''Represents the status of a policy update from the JAX planner,
+ including whether the update resulted in nan gradient,
+ whether progress was made, budget was reached, or other information that
+ can be used to monitor and act based on the planner's progress.'''
+
+ NORMAL = 0
+ NO_PROGRESS = 1
+ PRECONDITION_POSSIBLY_UNSATISFIED = 2
+ INVALID_GRADIENT = 3
+ TIME_BUDGET_REACHED = 4
+ ITER_BUDGET_REACHED = 5
+
+ def is_failure(self) -> bool:
+ return self.value >= 3
+
+
  class JaxBackpropPlanner:
  '''A class for optimizing an action sequence in the given RDDL MDP using
  gradient descent.'''

  def __init__(self, rddl: RDDLLiftedModel,
  plan: JaxPlan,
- batch_size_train: int,
- batch_size_test: int=None,
- rollout_horizon: int=None,
+ batch_size_train: int=32,
+ batch_size_test: Optional[int]=None,
+ rollout_horizon: Optional[int]=None,
  use64bit: bool=False,
- action_bounds: Dict[str, Tuple[np.ndarray, np.ndarray]]={},
+ action_bounds: Optional[Bounds]=None,
  optimizer: Callable[..., optax.GradientTransformation]=optax.rmsprop,
- optimizer_kwargs: Dict[str, object]={'learning_rate': 0.1},
- clip_grad: float=None,
+ optimizer_kwargs: Optional[Kwargs]=None,
+ clip_grad: Optional[float]=None,
  logic: FuzzyLogic=FuzzyLogic(),
  use_symlog_reward: bool=False,
- utility=jnp.mean,
- cpfs_without_grad: Set=set()) -> None:
+ utility: Union[Callable[[jnp.ndarray], float], str]='mean',
+ utility_kwargs: Optional[Kwargs]=None,
+ cpfs_without_grad: Optional[Set[str]]=None,
+ compile_non_fluent_exact: bool=True,
+ logger: Optional[Logger]=None) -> None:
  '''Creates a new gradient-based algorithm for optimizing action sequences
  (plan) in the given RDDL. Some operations will be converted to their
  differentiable counterparts; the specific operations can be customized
@@ -946,9 +1255,16 @@ class JaxBackpropPlanner:
  :param use_symlog_reward: whether to use the symlog transform on the
  reward as a form of normalization
  :param utility: how to aggregate return observations to compute utility
- of a policy or plan
+ of a policy or plan; must be either a function mapping jax array to a
+ scalar, or a a string identifying the utility function by name
+ ("mean", "mean_var", "entropic", or "cvar" are currently supported)
+ :param utility_kwargs: additional keyword arguments to pass hyper-
+ parameters to the utility function call
  :param cpfs_without_grad: which CPFs do not have gradients (use straight
  through gradient trick)
+ :param compile_non_fluent_exact: whether non-fluent expressions
+ are always compiled using exact JAX expressions
+ :param logger: to log information about compilation to file
  '''
  self.rddl = rddl
  self.plan = plan
@@ -959,22 +1275,25 @@ class JaxBackpropPlanner:
  if rollout_horizon is None:
  rollout_horizon = rddl.horizon
  self.horizon = rollout_horizon
+ if action_bounds is None:
+ action_bounds = {}
  self._action_bounds = action_bounds
  self.use64bit = use64bit
  self._optimizer_name = optimizer
+ if optimizer_kwargs is None:
+ optimizer_kwargs = {'learning_rate': 0.1}
  self._optimizer_kwargs = optimizer_kwargs
  self.clip_grad = clip_grad

  # set optimizer
  try:
  optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs)
- except:
+ except Exception as _:
  raise_warning(
  'Failed to inject hyperparameters into optax optimizer, '
  'rolling back to safer method: please note that modification of '
  'optimizer hyperparameters will not work, and it is '
- 'recommended to update your packages and Python distribution.',
- 'red')
+ 'recommended to update optax and related packages.', 'red')
  optimizer = optimizer(**optimizer_kwargs)
  if clip_grad is None:
  self.optimizer = optimizer
@@ -983,33 +1302,84 @@ class JaxBackpropPlanner:
  optax.clip(clip_grad),
  optimizer
  )
-
+
+ # set utility
+ if isinstance(utility, str):
+ utility = utility.lower()
+ if utility == 'mean':
+ utility_fn = jnp.mean
+ elif utility == 'mean_var':
+ utility_fn = mean_variance_utility
+ elif utility == 'entropic':
+ utility_fn = entropic_utility
+ elif utility == 'cvar':
+ utility_fn = cvar_utility
+ else:
+ raise RDDLNotImplementedError(
+ f'Utility function <{utility}> is not supported: '
+ 'must be one of ["mean", "mean_var", "entropic", "cvar"].')
+ else:
+ utility_fn = utility
+ self.utility = utility_fn
+
+ if utility_kwargs is None:
+ utility_kwargs = {}
+ self.utility_kwargs = utility_kwargs
+
  self.logic = logic
+ self.logic.set_use64bit(self.use64bit)
  self.use_symlog_reward = use_symlog_reward
- self.utility = utility
+ if cpfs_without_grad is None:
+ cpfs_without_grad = set()
  self.cpfs_without_grad = cpfs_without_grad
+ self.compile_non_fluent_exact = compile_non_fluent_exact
+ self.logger = logger

  self._jax_compile_rddl()
  self._jax_compile_optimizer()
-
- def summarize_hyperparameters(self):
- print(f'objective and relaxations:\n'
- f' objective_fn ={self.utility.__name__}\n'
- f' use_symlog ={self.use_symlog_reward}\n'
- f' lookahead ={self.horizon}\n'
- f' model relaxation={type(self.logic).__name__}\n'
- f' action_bounds ={self._action_bounds}\n'
- f' cpfs_no_gradient={self.cpfs_without_grad}\n'
+
+ def _summarize_system(self) -> None:
+ try:
+ jaxlib_version = jax._src.lib.version_str
+ except Exception as _:
+ jaxlib_version = 'N/A'
+ try:
+ devices_short = ', '.join(
+ map(str, jax._src.xla_bridge.devices())).replace('\n', '')
+ except Exception as _:
+ devices_short = 'N/A'
+ print('\n'
+ f'JAX Planner version {__version__}\n'
+ f'Python {sys.version}\n'
+ f'jax {jax.version.__version__}, jaxlib {jaxlib_version}, '
+ f'optax {optax.__version__}, haiku {hk.__version__}, '
+ f'numpy {np.__version__}\n'
+ f'devices: {devices_short}\n')
+
+ def summarize_hyperparameters(self) -> None:
+ print(f'objective hyper-parameters:\n'
+ f' utility_fn ={self.utility.__name__}\n'
+ f' utility args ={self.utility_kwargs}\n'
+ f' use_symlog ={self.use_symlog_reward}\n'
+ f' lookahead ={self.horizon}\n'
+ f' user_action_bounds={self._action_bounds}\n'
+ f' fuzzy logic type ={type(self.logic).__name__}\n'
+ f' nonfluents exact ={self.compile_non_fluent_exact}\n'
+ f' cpfs_no_gradient ={self.cpfs_without_grad}\n'
  f'optimizer hyper-parameters:\n'
- f' use_64_bit ={self.use64bit}\n'
- f' optimizer ={self._optimizer_name.__name__}\n'
- f' optimizer args ={self._optimizer_kwargs}\n'
- f' clip_gradient ={self.clip_grad}\n'
- f' batch_size_train={self.batch_size_train}\n'
- f' batch_size_test ={self.batch_size_test}')
+ f' use_64_bit ={self.use64bit}\n'
+ f' optimizer ={self._optimizer_name.__name__}\n'
+ f' optimizer args ={self._optimizer_kwargs}\n'
+ f' clip_gradient ={self.clip_grad}\n'
+ f' batch_size_train ={self.batch_size_train}\n'
+ f' batch_size_test ={self.batch_size_test}')
  self.plan.summarize_hyperparameters()
  self.logic.summarize_hyperparameters()

+ # ===========================================================================
+ # COMPILATION SUBROUTINES
+ # ===========================================================================
+
  def _jax_compile_rddl(self):
  rddl = self.rddl
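The utility argument of JaxBackpropPlanner now also accepts the names 'mean', 'mean_var', 'entropic' and 'cvar', with utility_kwargs forwarded to the utility call. A construction sketch (the model variable and the keyword expected by the CVaR utility are assumptions, not taken from this diff):

from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxStraightLinePlan

planner = JaxBackpropPlanner(
    rddl=model,                      # a pyRDDLGym RDDLLiftedModel, assumed in scope
    plan=JaxStraightLinePlan(),
    batch_size_train=32,
    utility='cvar',                  # or 'mean', 'mean_var', 'entropic', or a callable
    utility_kwargs={'alpha': 0.05})  # hypothetical kwarg name for the risk level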
@@ -1017,13 +1387,18 @@ class JaxBackpropPlanner:
  self.compiled = JaxRDDLCompilerWithGrad(
  rddl=rddl,
  logic=self.logic,
+ logger=self.logger,
  use64bit=self.use64bit,
- cpfs_without_grad=self.cpfs_without_grad)
- self.compiled.compile()
+ cpfs_without_grad=self.cpfs_without_grad,
+ compile_non_fluent_exact=self.compile_non_fluent_exact)
+ self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')

  # Jax compilation of the exact RDDL for testing
- self.test_compiled = JaxRDDLCompiler(rddl=rddl, use64bit=self.use64bit)
- self.test_compiled.compile()
+ self.test_compiled = JaxRDDLCompiler(
+ rddl=rddl,
+ logger=self.logger,
+ use64bit=self.use64bit)
+ self.test_compiled.compile(log_jax_expr=True, heading='EXACT MODEL')

  def _jax_compile_optimizer(self):

@@ -1039,6 +1414,7 @@ class JaxBackpropPlanner:
  policy=self.plan.train_policy,
  n_steps=self.horizon,
  n_batch=self.batch_size_train)
+ self.train_rollouts = train_rollouts

  test_rollouts = self.test_compiled.compile_rollouts(
  policy=self.plan.test_policy,
@@ -1051,11 +1427,10 @@ class JaxBackpropPlanner:

  # losses
  train_loss = self._jax_loss(train_rollouts, use_symlog=self.use_symlog_reward)
- self.train_loss = jax.jit(train_loss)
  self.test_loss = jax.jit(self._jax_loss(test_rollouts, use_symlog=False))

  # optimization
- self.update = jax.jit(self._jax_update(train_loss))
+ self.update = self._jax_update(train_loss)

  def _jax_return(self, use_symlog):
  gamma = self.rddl.discount
@@ -1068,13 +1443,14 @@ class JaxBackpropPlanner:
  rewards = rewards * discount[jnp.newaxis, ...]
  returns = jnp.sum(rewards, axis=1)
  if use_symlog:
- returns = jnp.sign(returns) * jnp.log1p(jnp.abs(returns))
+ returns = jnp.sign(returns) * jnp.log(1.0 + jnp.abs(returns))
  return returns

  return _jax_wrapped_returns

  def _jax_loss(self, rollouts, use_symlog=False):
- utility_fn = self.utility
+ utility_fn = self.utility
+ utility_kwargs = self.utility_kwargs
  _jax_wrapped_returns = self._jax_return(use_symlog)

  # the loss is the average cumulative reward across all roll-outs
@@ -1083,7 +1459,7 @@ class JaxBackpropPlanner:
  log = rollouts(key, policy_params, hyperparams, subs, model_params)
  rewards = log['reward']
  returns = _jax_wrapped_returns(rewards)
- utility = utility_fn(returns)
+ utility = utility_fn(returns, **utility_kwargs)
  loss = -utility
  return loss, log

@@ -1096,7 +1472,7 @@ class JaxBackpropPlanner:
  def _jax_wrapped_init_policy(key, hyperparams, subs):
  policy_params = init(key, hyperparams, subs)
  opt_state = optimizer.init(policy_params)
- return policy_params, opt_state
+ return policy_params, opt_state, None

  return _jax_wrapped_init_policy

@@ -1107,17 +1483,18 @@ class JaxBackpropPlanner:
  # calculate the plan gradient w.r.t. return loss and update optimizer
  # also perform a projection step to satisfy constraints on actions
  def _jax_wrapped_plan_update(key, policy_params, hyperparams,
- subs, model_params, opt_state):
- grad_fn = jax.grad(loss, argnums=1, has_aux=True)
- grad, log = grad_fn(key, policy_params, hyperparams, subs, model_params)
+ subs, model_params, opt_state, opt_aux):
+ grad_fn = jax.value_and_grad(loss, argnums=1, has_aux=True)
+ (loss_val, log), grad = grad_fn(
+ key, policy_params, hyperparams, subs, model_params)
  updates, opt_state = optimizer.update(grad, opt_state)
  policy_params = optax.apply_updates(policy_params, updates)
  policy_params, converged = projection(policy_params, hyperparams)
  log['grad'] = grad
  log['updates'] = updates
- return policy_params, converged, opt_state, log
+ return policy_params, converged, opt_state, None, loss_val, log

- return _jax_wrapped_plan_update
+ return jax.jit(_jax_wrapped_plan_update)

  def _batched_init_subs(self, subs):
  rddl = self.rddl
@@ -1145,15 +1522,106 @@ class JaxBackpropPlanner:

  return init_train, init_test

- def optimize(self, *args, return_callback: bool=False, **kwargs) -> object:
- ''' Compute an optimal straight-line plan. Returns the parameters
- for the optimized policy.
+ def as_optimization_problem(
+ self, key: Optional[random.PRNGKey]=None,
+ policy_hyperparams: Optional[Pytree]=None,
+ loss_function_updates_key: bool=True,
+ grad_function_updates_key: bool=False) -> Tuple[Callable, Callable, np.ndarray, Callable]:
+ '''Returns a function that computes the loss and a function that
+ computes gradient of the return as a 1D vector given a 1D representation
+ of policy parameters. These functions are designed to be compatible with
+ off-the-shelf optimizers such as scipy.
+
+ Also returns the initial parameter vector to seed an optimizer,
+ as well as a mapping that recovers the parameter pytree from the vector.
+ The PRNG key is updated internally starting from the optional given key.
+
+ Constraints on actions, if they are required, cannot be constructed
+ automatically in the general case. The user should build constraints
+ for each problem in the format required by the downstream optimizer.
+
+ :param key: JAX PRNG key (derived from clock if not provided)
+ :param policy_hyperparameters: hyper-parameters for the policy/plan,
+ such as weights for sigmoid wrapping boolean actions (defaults to 1
+ for all action-fluents if not provided)
+ :param loss_function_updates_key: if True, the loss function
+ updates the PRNG key internally independently of the grad function
+ :param grad_function_updates_key: if True, the gradient function
+ updates the PRNG key internally independently of the loss function.
+ '''

- :param key: JAX PRNG key
+ # if PRNG key is not provided
+ if key is None:
+ key = random.PRNGKey(round(time.time() * 1000))
+
+ # initialize the initial fluents, model parameters, policy hyper-params
+ subs = self.test_compiled.init_values
+ train_subs, _ = self._batched_init_subs(subs)
+ model_params = self.compiled.model_params
+ if policy_hyperparams is None:
+ raise_warning('policy_hyperparams is not set, setting 1.0 for '
+ 'all action-fluents which could be suboptimal.')
+ policy_hyperparams = {action: 1.0
+ for action in self.rddl.action_fluents}
+
+ # initialize the policy parameters
+ params_guess, *_ = self.initialize(key, policy_hyperparams, train_subs)
+ guess_1d, unravel_fn = jax.flatten_util.ravel_pytree(params_guess)
+ guess_1d = np.asarray(guess_1d)
+
+ # computes the training loss function and its 1D gradient
+ loss_fn = self._jax_loss(self.train_rollouts)
+
+ @jax.jit
+ def _loss_with_key(key, params_1d):
+ policy_params = unravel_fn(params_1d)
+ loss_val, _ = loss_fn(key, policy_params, policy_hyperparams,
+ train_subs, model_params)
+ return loss_val
+
+ @jax.jit
+ def _grad_with_key(key, params_1d):
+ policy_params = unravel_fn(params_1d)
+ grad_fn = jax.grad(loss_fn, argnums=1, has_aux=True)
+ grad_val, _ = grad_fn(key, policy_params, policy_hyperparams,
+ train_subs, model_params)
+ grad_1d = jax.flatten_util.ravel_pytree(grad_val)[0]
+ return grad_1d
+
+ def _loss_function(params_1d):
+ nonlocal key
+ if loss_function_updates_key:
+ key, subkey = random.split(key)
+ else:
+ subkey = key
+ loss_val = _loss_with_key(subkey, params_1d)
+ loss_val = float(loss_val)
+ return loss_val
+
+ def _grad_function(params_1d):
+ nonlocal key
+ if grad_function_updates_key:
+ key, subkey = random.split(key)
+ else:
+ subkey = key
+ grad = _grad_with_key(subkey, params_1d)
+ grad = np.asarray(grad)
+ return grad
+
+ return _loss_function, _grad_function, guess_1d, jax.jit(unravel_fn)
+
+ # ===========================================================================
+ # OPTIMIZE API
+ # ===========================================================================
+
+ def optimize(self, *args, **kwargs) -> Dict[str, Any]:
+ '''Compute an optimal policy or plan. Return the callback from training.
+
+ :param key: JAX PRNG key (derived from clock if not provided)
  :param epochs: the maximum number of steps of gradient descent
- :param the maximum number of steps of gradient descent
  :param train_seconds: total time allocated for gradient descent
  :param plot_step: frequency to plot the plan and save result to disk
+ :param plot_kwargs: additional arguments to pass to the plotter
  :param model_params: optional model-parameters to override default
  :param policy_hyperparams: hyper-parameters for the policy/plan, such as
  weights for sigmoid wrapping boolean actions
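The as_optimization_problem method added in the hunk above exposes the training loss and its gradient over a flat parameter vector, presumably what the new examples/run_scipy.py exercises. A minimal sketch with scipy (the planner object is assumed to be constructed as elsewhere in this file):

from scipy.optimize import minimize

loss_fn, grad_fn, x0, unravel_fn = planner.as_optimization_problem()
result = minimize(loss_fn, x0, jac=grad_fn, method='L-BFGS-B',
                  options={'maxiter': 200})
best_params = unravel_fn(result.x)   # recover the policy parameter pytree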
@@ -1161,64 +1629,110 @@ class JaxBackpropPlanner:
1161
1629
  their values: if None initializes all variables from the RDDL instance
1162
1630
  :param guess: initial policy parameters: if None will use the initializer
1163
1631
  specified in this instance
1164
- :param verbose: not print (0), print summary (1), print progress (2)
1165
- :param return_callback: whether to return the callback from training
1166
- instead of the parameters
1632
+ :param print_summary: whether to print planner header, parameter
1633
+ summary, and diagnosis
1634
+ :param print_progress: whether to print the progress bar during training
1635
+ :param test_rolling_window: the test return is averaged on a rolling
1636
+ window of the past test_rolling_window returns when updating the best
1637
+ parameters found so far
1638
+ :param tqdm_position: position of tqdm progress bar (for multiprocessing)
1167
1639
  '''
1168
1640
  it = self.optimize_generator(*args, **kwargs)
1169
- callback = deque(it, maxlen=1).pop()
1170
- if return_callback:
1171
- return callback
1641
+
1642
+ # if the Python implementation is CPython, the deque consumes the iterator
1642
+ # in native C and is much faster than a plain loop, but this does not hold
1643
+ # for other implementations (e.g. PyPy); for details, see
1645
+ # https://stackoverflow.com/questions/50937966/fastest-most-pythonic-way-to-consume-an-iterator
1646
+ callback = None
1647
+ if sys.implementation.name == 'cpython':
1648
+ last_callback = deque(it, maxlen=1)
1649
+ if last_callback:
1650
+ callback = last_callback.pop()
1172
1651
  else:
1173
- return callback['best_params']
1652
+ for callback in it:
1653
+ pass
1654
+ return callback
1174
1655
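Since optimize() now exhausts the generator and always returns the final callback dictionary (there is no longer a return_callback switch), callers extract the trained parameters from the callback themselves. A minimal sketch, assuming a JaxBackpropPlanner instance planner has already been constructed for a compiled RDDL problem:

import jax.random as random

callback = planner.optimize(
    key=random.PRNGKey(42),
    epochs=1000,
    train_seconds=30.0,
    print_summary=True,
    print_progress=True)

# keys below are among those yielded by the training loop
best_params = callback['best_params']
print('status:', callback['status'], ' test return:', callback['test_return'])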
 
1175
- def optimize_generator(self, key: random.PRNGKey,
1656
+ def optimize_generator(self, key: Optional[random.PRNGKey]=None,
1176
1657
  epochs: int=999999,
1177
1658
  train_seconds: float=120.,
1178
- plot_step: int=None,
1179
- model_params: Dict[str, object]=None,
1180
- policy_hyperparams: Dict[str, object]=None,
1181
- subs: Dict[str, object]=None,
1182
- guess: Dict[str, object]=None,
1183
- verbose: int=2,
1184
- tqdm_position: int=None) -> Generator[Dict[str, object], None, None]:
1185
- '''Returns a generator for computing an optimal straight-line plan.
1659
+ plot_step: Optional[int]=None,
1660
+ plot_kwargs: Optional[Dict[str, Any]]=None,
1661
+ model_params: Optional[Dict[str, Any]]=None,
1662
+ policy_hyperparams: Optional[Dict[str, Any]]=None,
1663
+ subs: Optional[Dict[str, Any]]=None,
1664
+ guess: Optional[Pytree]=None,
1665
+ print_summary: bool=True,
1666
+ print_progress: bool=True,
1667
+ test_rolling_window: int=10,
1668
+ tqdm_position: Optional[int]=None) -> Generator[Dict[str, Any], None, None]:
1669
+ '''Returns a generator for computing an optimal policy or plan.
1186
1670
  Generator can be iterated over to lazily optimize the plan, yielding
1187
1671
  a dictionary of intermediate computations.
1188
1672
 
1189
- :param key: JAX PRNG key
1673
+ :param key: JAX PRNG key (derived from clock if not provided)
1190
1674
  :param epochs: the maximum number of steps of gradient descent
1191
- :param the maximum number of steps of gradient descent
1192
1675
  :param train_seconds: total time allocated for gradient descent
1193
1676
  :param plot_step: frequency to plot the plan and save result to disk
1677
+ :param plot_kwargs: additional arguments to pass to the plotter
1194
1678
  :param model_params: optional model-parameters to override default
1195
1679
  :param policy_hyperparams: hyper-parameters for the policy/plan, such as
1196
1680
  weights for sigmoid wrapping boolean actions
1197
1681
  :param subs: dictionary mapping initial state and non-fluents to
1198
1682
  their values: if None initializes all variables from the RDDL instance
1199
1683
  :param guess: initial policy parameters: if None will use the initializer
1200
- specified in this instance
1201
- :param verbose: not print (0), print summary (1), print progress (2)
1684
+ specified in this instance
1685
+ :param print_summary: whether to print planner header, parameter
1686
+ summary, and diagnosis
1687
+ :param print_progress: whether to print the progress bar during training
1688
+ :param test_rolling_window: the test return is averaged over a rolling
1689
+ window of the past test_rolling_window returns when updating the best
1690
+ parameters found so far
1202
1691
  :param tqdm_position: position of tqdm progress bar (for multiprocessing)
1203
1692
  '''
1204
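optimize_generator() yields the same callback dictionary after every gradient step, so training can be monitored or stopped lazily instead of waiting for optimize() to finish. A minimal sketch, assuming an existing planner instance; the callback keys and the is_failure() status check mirror the training loop further below:

import jax.random as random

last = None
for cb in planner.optimize_generator(key=random.PRNGKey(42),
                                     epochs=500,
                                     train_seconds=60.0,
                                     print_progress=False):
    last = cb
    if cb['iteration'] % 50 == 0:
        print(f"it={cb['iteration']}  test return={cb['test_return']}")
    if cb['status'].is_failure():   # stop early on a failure status
        break

best_params = last['best_params']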
- verbose = int(verbose)
1205
1693
  start_time = time.time()
1206
1694
  elapsed_outside_loop = 0
1207
1695
 
1696
+ # if PRNG key is not provided
1697
+ if key is None:
1698
+ key = random.PRNGKey(round(time.time() * 1000))
1699
+
1700
+ # if policy_hyperparams is not provided
1701
+ if policy_hyperparams is None:
1702
+ raise_warning('policy_hyperparams is not set, setting 1.0 for '
1703
+ 'all action-fluents, which could be suboptimal.')
1704
+ policy_hyperparams = {action: 1.0
1705
+ for action in self.rddl.action_fluents}
1706
+
1707
+ # if policy_hyperparams is a scalar
1708
+ elif isinstance(policy_hyperparams, (int, float, np.number)):
1709
+ raise_warning(f'policy_hyperparams is {policy_hyperparams}, '
1710
+ 'setting this value for all action-fluents.')
1711
+ hyperparam_value = float(policy_hyperparams)
1712
+ policy_hyperparams = {action: hyperparam_value
1713
+ for action in self.rddl.action_fluents}
1714
+
1208
1715
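The two branches above accept policy_hyperparams either as a per-action dictionary or as a bare scalar that is broadcast (with a warning) to every action-fluent. A short sketch with hypothetical action-fluent names, chosen purely for illustration:

# per-action sigmoid weights for (hypothetical) boolean action-fluents;
# larger weights make the relaxed boolean actions sharper
policy_hyperparams = {'open-valve': 10.0, 'close-valve': 10.0}

# equivalent shortcut: a bare scalar is broadcast to all action-fluents
policy_hyperparams = 5.0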
  # print summary of parameters:
1209
- if verbose >= 1:
1210
- print('==============================================\n'
1211
- 'JAX PLANNER PARAMETER SUMMARY\n'
1212
- '==============================================')
1716
+ if print_summary:
1717
+ self._summarize_system()
1213
1718
  self.summarize_hyperparameters()
1214
1719
  print(f'optimize() call hyper-parameters:\n'
1720
+ f' PRNG key ={key}\n'
1215
1721
  f' max_iterations ={epochs}\n'
1216
1722
  f' max_seconds ={train_seconds}\n'
1217
1723
  f' model_params ={model_params}\n'
1218
1724
  f' policy_hyper_params={policy_hyperparams}\n'
1219
1725
  f' override_subs_dict ={subs is not None}\n'
1220
- f' provide_param_guess={guess is not None}\n'
1221
- f' plot_frequency ={plot_step}\n')
1726
+ f' provide_param_guess={guess is not None}\n'
1727
+ f' test_rolling_window={test_rolling_window}\n'
1728
+ f' plot_frequency ={plot_step}\n'
1729
+ f' plot_kwargs ={plot_kwargs}\n'
1730
+ f' print_summary ={print_summary}\n'
1731
+ f' print_progress ={print_progress}\n')
1732
+ if self.compiled.relaxations:
1733
+ print('Some RDDL operations are non-differentiable, '
1734
+ 'replacing them with differentiable relaxations:')
1735
+ print(self.compiled.summarize_model_relaxations())
1222
1736
 
1223
1737
  # compute a batched version of the initial values
1224
1738
  if subs is None:
@@ -1237,7 +1751,7 @@ class JaxBackpropPlanner:
1237
1751
  'from the RDDL files.')
1238
1752
  train_subs, test_subs = self._batched_init_subs(subs)
1239
1753
 
1240
- # initialize, model parameters
1754
+ # initialize model parameters
1241
1755
  if model_params is None:
1242
1756
  model_params = self.compiled.model_params
1243
1757
  model_params_test = self.test_compiled.model_params
@@ -1245,63 +1759,103 @@ class JaxBackpropPlanner:
1245
1759
  # initialize policy parameters
1246
1760
  if guess is None:
1247
1761
  key, subkey = random.split(key)
1248
- policy_params, opt_state = self.initialize(
1762
+ policy_params, opt_state, opt_aux = self.initialize(
1249
1763
  subkey, policy_hyperparams, train_subs)
1250
1764
  else:
1251
1765
  policy_params = guess
1252
1766
  opt_state = self.optimizer.init(policy_params)
1767
+ opt_aux = None
1768
+
1769
+ # initialize running statistics
1253
1770
  best_params, best_loss, best_grad = policy_params, jnp.inf, jnp.inf
1254
1771
  last_iter_improve = 0
1772
+ rolling_test_loss = RollingMean(test_rolling_window)
1255
1773
  log = {}
1774
+ status = JaxPlannerStatus.NORMAL
1775
+
1776
+ # initialize plot area
1777
+ if plot_step is None or plot_step <= 0 or plt is None:
1778
+ plot = None
1779
+ else:
1780
+ if plot_kwargs is None:
1781
+ plot_kwargs = {}
1782
+ plot = JaxPlannerPlot(self.rddl, self.horizon, **plot_kwargs)
1783
+ xticks, loss_values = [], []
1256
1784
 
1257
1785
  # training loop
1258
1786
  iters = range(epochs)
1259
- if verbose >= 2:
1787
+ if print_progress:
1260
1788
  iters = tqdm(iters, total=100, position=tqdm_position)
1261
1789
 
1262
1790
  for it in iters:
1791
+ status = JaxPlannerStatus.NORMAL
1263
1792
 
1264
1793
  # update the parameters of the plan
1265
- key, subkey1, subkey2, subkey3 = random.split(key, num=4)
1266
- policy_params, converged, opt_state, train_log = self.update(
1267
- subkey1, policy_params, policy_hyperparams,
1268
- train_subs, model_params, opt_state)
1794
+ key, subkey = random.split(key)
1795
+ policy_params, converged, opt_state, opt_aux, \
1796
+ train_loss, train_log = \
1797
+ self.update(subkey, policy_params, policy_hyperparams,
1798
+ train_subs, model_params, opt_state, opt_aux)
1799
+
1800
+ # no progress
1801
+ grad_norm_zero, _ = jax.tree_util.tree_flatten(
1802
+ jax.tree_map(lambda x: np.allclose(x, 0), train_log['grad']))
1803
+ if np.all(grad_norm_zero):
1804
+ status = JaxPlannerStatus.NO_PROGRESS
1805
+
1806
+ # constraint satisfaction problem
1269
1807
  if not np.all(converged):
1270
1808
  raise_warning(
1271
1809
  'Projected gradient method for satisfying action concurrency '
1272
1810
  'constraints reached the iteration limit: plan is possibly '
1273
1811
  'invalid for the current instance.', 'red')
1812
+ status = JaxPlannerStatus.PRECONDITION_POSSIBLY_UNSATISFIED
1274
1813
 
1275
- # evaluate losses
1276
- train_loss, _ = self.train_loss(
1277
- subkey2, policy_params, policy_hyperparams,
1278
- train_subs, model_params)
1814
+ # numerical error
1815
+ if not np.isfinite(train_loss):
1816
+ raise_warning(
1817
+ f'Aborting JAX planner due to invalid train loss {train_loss}.',
1818
+ 'red')
1819
+ status = JaxPlannerStatus.INVALID_GRADIENT
1820
+
1821
+ # evaluate test losses and record best plan so far
1279
1822
  test_loss, log = self.test_loss(
1280
- subkey3, policy_params, policy_hyperparams,
1823
+ subkey, policy_params, policy_hyperparams,
1281
1824
  test_subs, model_params_test)
1282
-
1283
- # record the best plan so far
1825
+ test_loss = rolling_test_loss.update(test_loss)
1284
1826
  if test_loss < best_loss:
1285
1827
  best_params, best_loss, best_grad = \
1286
1828
  policy_params, test_loss, train_log['grad']
1287
1829
  last_iter_improve = it
1288
1830
 
1289
1831
  # save the plan figure
1290
- if plot_step is not None and it % plot_step == 0:
1291
- self._plot_actions(
1292
- key, policy_params, policy_hyperparams, test_subs, it)
1832
+ if plot is not None and it % plot_step == 0:
1833
+ xticks.append(it // plot_step)
1834
+ loss_values.append(test_loss.item())
1835
+ action_values = {name: values
1836
+ for (name, values) in log['fluents'].items()
1837
+ if name in self.rddl.action_fluents}
1838
+ returns = -np.sum(np.asarray(log['reward']), axis=1)
1839
+ plot.redraw(xticks, loss_values, action_values, returns)
1293
1840
 
1294
1841
  # if the progress bar is used
1295
1842
  elapsed = time.time() - start_time - elapsed_outside_loop
1296
- if verbose >= 2:
1843
+ if print_progress:
1297
1844
  iters.n = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
1298
1845
  iters.set_description(
1299
- f'[{tqdm_position}] {it:6} it / {-train_loss:14.4f} train / '
1300
- f'{-test_loss:14.4f} test / {-best_loss:14.4f} best')
1846
+ f'[{tqdm_position}] {it:6} it / {-train_loss:14.6f} train / '
1847
+ f'{-test_loss:14.6f} test / {-best_loss:14.6f} best')
1848
+
1849
+ # reached computation budget
1850
+ if elapsed >= train_seconds:
1851
+ status = JaxPlannerStatus.TIME_BUDGET_REACHED
1852
+ if it >= epochs - 1:
1853
+ status = JaxPlannerStatus.ITER_BUDGET_REACHED
1301
1854
 
1302
1855
  # return a callback
1303
1856
  start_time_outside = time.time()
1304
1857
  yield {
1858
+ 'status': status,
1305
1859
  'iteration': it,
1306
1860
  'train_return':-train_loss,
1307
1861
  'test_return':-test_loss,
@@ -1318,16 +1872,15 @@ class JaxBackpropPlanner:
1318
1872
  }
1319
1873
  elapsed_outside_loop += (time.time() - start_time_outside)
1320
1874
 
1321
- # reached time budget
1322
- if elapsed >= train_seconds:
1323
- break
1324
-
1325
- # numerical error
1326
- if not np.isfinite(train_loss):
1875
+ # abortion check
1876
+ if status.is_failure():
1327
1877
  break
1328
-
1329
- if verbose >= 2:
1878
+
1879
+ # release resources
1880
+ if print_progress:
1330
1881
  iters.close()
1882
+ if plot is not None:
1883
+ plot.close()
1331
1884
 
1332
1885
  # validate the test return
1333
1886
  if log:
@@ -1337,24 +1890,23 @@ class JaxBackpropPlanner:
1337
1890
  if messages:
1338
1891
  messages = '\n'.join(messages)
1339
1892
  raise_warning('The JAX compiler encountered the following '
1340
- 'problems in the original RDDL '
1893
+ 'error(s) in the original RDDL formulation '
1341
1894
  f'during test evaluation:\n{messages}', 'red')
1342
1895
 
1343
1896
  # summarize and test for convergence
1344
- if verbose >= 1:
1345
- grad_norm = jax.tree_map(
1346
- lambda x: np.array(jnp.linalg.norm(x)).item(), best_grad)
1897
+ if print_summary:
1898
+ grad_norm = jax.tree_map(lambda x: np.linalg.norm(x).item(), best_grad)
1347
1899
  diagnosis = self._perform_diagnosis(
1348
- last_iter_improve, it,
1349
- -train_loss, -test_loss, -best_loss, grad_norm)
1900
+ last_iter_improve, -train_loss, -test_loss, -best_loss, grad_norm)
1350
1901
  print(f'summary of optimization:\n'
1902
+ f' status_code ={status}\n'
1351
1903
  f' time_elapsed ={elapsed}\n'
1352
1904
  f' iterations ={it}\n'
1353
1905
  f' best_objective={-best_loss}\n'
1354
- f' grad_norm ={grad_norm}\n'
1906
+ f' best_grad_norm={grad_norm}\n'
1355
1907
  f'diagnosis: {diagnosis}\n')
1356
1908
 
1357
- def _perform_diagnosis(self, last_iter_improve, total_it,
1909
+ def _perform_diagnosis(self, last_iter_improve,
1358
1910
  train_return, test_return, best_return, grad_norm):
1359
1911
  max_grad_norm = max(jax.tree_util.tree_leaves(grad_norm))
1360
1912
  grad_is_zero = np.allclose(max_grad_norm, 0)
@@ -1373,20 +1925,20 @@ class JaxBackpropPlanner:
1373
1925
  if grad_is_zero:
1374
1926
  return termcolor.colored(
1375
1927
  '[FAILURE] no progress was made, '
1376
- f'and max grad norm = {max_grad_norm}, '
1377
- 'likely stuck in a plateau.', 'red')
1928
+ f'and max grad norm {max_grad_norm:.6f} is zero: '
1929
+ 'solver likely stuck in a plateau.', 'red')
1378
1930
  else:
1379
1931
  return termcolor.colored(
1380
1932
  '[FAILURE] no progress was made, '
1381
- f'but max grad norm = {max_grad_norm} > 0, '
1382
- 'likely due to bad l.r. or other hyper-parameter.', 'red')
1933
+ f'but max grad norm {max_grad_norm:.6f} is non-zero: '
1934
+ 'likely poor learning rate or other hyper-parameter.', 'red')
1383
1935
 
1384
1936
  # model is likely poor IF:
1385
1937
  # 1. the train and test return disagree
1386
1938
  if not (validation_error < 20):
1387
1939
  return termcolor.colored(
1388
1940
  '[WARNING] progress was made, '
1389
- f'but relative train test error = {validation_error} is high, '
1941
+ f'but relative train-test error {validation_error:.6f} is high: '
1390
1942
  'likely poor model relaxation around the solution, '
1391
1943
  'or the batch size is too small.', 'yellow')
1392
1944
 
@@ -1397,208 +1949,213 @@ class JaxBackpropPlanner:
1397
1949
  if not (return_to_grad_norm > 1):
1398
1950
  return termcolor.colored(
1399
1951
  '[WARNING] progress was made, '
1400
- f'but max grad norm = {max_grad_norm} is high, '
1401
- 'likely indicates the solution is not locally optimal, '
1402
- 'or the model is not smooth around the solution, '
1952
+ f'but max grad norm {max_grad_norm:.6f} is high: '
1953
+ 'likely the solution is not locally optimal, '
1954
+ 'or the relaxed model is not smooth around the solution, '
1403
1955
  'or the batch size is too small.', 'yellow')
1404
1956
 
1405
1957
  # likely successful
1406
1958
  return termcolor.colored(
1407
- '[SUCCESS] planner appears to have converged successfully '
1959
+ '[SUCCESS] planner has converged successfully '
1408
1960
  '(note: not all potential problems can be ruled out).', 'green')
1409
1961
 
1410
1962
  def get_action(self, key: random.PRNGKey,
1411
- params: Dict,
1963
+ params: Pytree,
1412
1964
  step: int,
1413
- subs: Dict,
1414
- policy_hyperparams: Dict[str, object]=None) -> Dict[str, object]:
1965
+ subs: Dict[str, Any],
1966
+ policy_hyperparams: Optional[Dict[str, Any]]=None) -> Dict[str, Any]:
1415
1967
  '''Returns an action dictionary from the policy or plan with the given
1416
1968
  parameters.
1417
1969
 
1418
1970
  :param key: the JAX PRNG key
1419
1971
  :param params: the trainable parameter PyTree of the policy
1420
1972
  :param step: the time step at which decision is made
1421
- :param policy_hyperparams: hyper-parameters for the policy/plan, such as
1422
- weights for sigmoid wrapping boolean actions
1423
1973
  :param subs: the dict of pvariables
1974
+ :param policy_hyperparams: hyper-parameters for the policy/plan, such as
1975
+ weights for sigmoid wrapping boolean actions (optional)
1424
1976
  '''
1425
1977
 
1426
1978
  # check compatibility of the subs dictionary
1427
- for var in subs.keys():
1979
+ for (var, values) in subs.items():
1980
+
1981
+ # must not be grounded
1428
1982
  if RDDLPlanningModel.FLUENT_SEP in var \
1429
1983
  or RDDLPlanningModel.OBJECT_SEP in var:
1430
- raise Exception(f'State dictionary passed to the JAX policy is '
1431
- f'grounded, since it contains the key <{var}>, '
1432
- f'but a vectorized environment is required: '
1433
- f'please make sure vectorized=True in the RDDLEnv.')
1434
-
1984
+ raise ValueError(f'State dictionary passed to the JAX policy is '
1985
+ f'grounded, since it contains the key <{var}>, '
1986
+ f'but a vectorized environment is required: '
1987
+ f'make sure vectorized = True in the RDDLEnv.')
1988
+
1989
+ # must be numeric array
1990
+ # exception is for POMDPs at 1st epoch when observ-fluents are None
1991
+ dtype = np.atleast_1d(values).dtype
1992
+ if not jnp.issubdtype(dtype, jnp.number) \
1993
+ and not jnp.issubdtype(dtype, jnp.bool_):
1994
+ if step == 0 and var in self.rddl.observ_fluents:
1995
+ subs[var] = self.test_compiled.init_values[var]
1996
+ else:
1997
+ raise ValueError(
1998
+ f'Values {values} assigned to p-variable <{var}> are '
1999
+ f'non-numeric of type {dtype}.')
2000
+
1435
2001
  # cast device arrays to numpy
1436
2002
  actions = self.test_policy(key, params, policy_hyperparams, step, subs)
1437
2003
  actions = jax.tree_map(np.asarray, actions)
1438
2004
  return actions
1439
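get_action() requires the vectorized (lifted) substitution dictionary produced by a pyRDDLGym environment with vectorized=True; grounded keys are rejected, and non-numeric observ-fluents are only tolerated at the first step of a POMDP. A minimal rollout sketch, assuming planner, trained params, and a vectorized environment env exist elsewhere (the Gym-style reset/step signatures and the horizon attribute are assumptions here):

import jax.random as random

key = random.PRNGKey(0)
state, _ = env.reset()
for step in range(env.horizon):
    key, subkey = random.split(key)
    action = planner.get_action(subkey, params, step, state)
    state, reward, terminated, truncated, _ = env.step(action)
    if terminated or truncated:
        break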
-
1440
- def _plot_actions(self, key, params, hyperparams, subs, it):
1441
- rddl = self.rddl
1442
- try:
1443
- import matplotlib.pyplot as plt
1444
- except Exception:
1445
- print('matplotlib is not installed, aborting plot...')
1446
- return
1447
-
1448
- # predict actions from the trained policy or plan
1449
- actions = self.test_rollouts(key, params, hyperparams, subs, {})['action']
1450
-
1451
- # plot the action sequences as color maps
1452
- fig, axs = plt.subplots(nrows=len(actions), constrained_layout=True)
1453
- for (ax, name) in zip(axs, actions):
1454
- action = np.mean(actions[name], axis=0, dtype=float)
1455
- action = np.reshape(action, newshape=(action.shape[0], -1)).T
1456
- if rddl.variable_ranges[name] == 'bool':
1457
- vmin, vmax = 0.0, 1.0
1458
- else:
1459
- vmin, vmax = None, None
1460
- img = ax.imshow(
1461
- action, vmin=vmin, vmax=vmax, cmap='seismic', aspect='auto')
1462
- ax.set_xlabel('time')
1463
- ax.set_ylabel(name)
1464
- plt.colorbar(img, ax=ax)
1465
-
1466
- # write plot to disk
1467
- plt.savefig(f'plan_{rddl.domain_name}_{rddl.instance_name}_{it}.pdf',
1468
- bbox_inches='tight')
1469
- plt.clf()
1470
- plt.close(fig)
1471
2005
 
1472
2006
 
1473
- class JaxArmijoLineSearchPlanner(JaxBackpropPlanner):
2007
+ class JaxLineSearchPlanner(JaxBackpropPlanner):
1474
2008
  '''A class for optimizing an action sequence in the given RDDL MDP using
1475
- Armijo linear search gradient descent.'''
2009
+ linear search gradient descent, with the Armijo condition.'''
1476
2010
 
1477
2011
  def __init__(self, *args,
1478
- optimizer: Callable[..., optax.GradientTransformation]=optax.sgd,
1479
- optimizer_kwargs: Dict[str, object]={'learning_rate': 1.0},
1480
- beta: float=0.8,
2012
+ decay: float=0.8,
1481
2013
  c: float=0.1,
1482
- lrmax: float=1.0,
1483
- lrmin: float=1e-5,
2014
+ step_max: float=1.0,
2015
+ step_min: float=1e-6,
1484
2016
  **kwargs) -> None:
1485
2017
  '''Creates a new gradient-based algorithm for optimizing action sequences
1486
- (plan) in the given RDDL using Armijo line search. All arguments are the
2018
+ (plan) in the given RDDL using line search. All arguments are the
1487
2019
  same as in the parent class, except:
1488
2020
 
1489
- :param beta: reduction factor of learning rate per line search iteration
1490
- :param c: coefficient in Armijo condition
1491
- :param lrmax: initial learning rate for line search
1492
- :param lrmin: minimum possible learning rate (line search halts)
2021
+ :param decay: reduction factor of learning rate per line search iteration
2022
+ :param c: positive coefficient in Armijo condition, should be in (0, 1)
2023
+ :param step_max: initial learning rate for line search
2024
+ :param step_min: minimum possible learning rate (line search halts)
1493
2025
  '''
1494
- self.beta = beta
2026
+ self.decay = decay
1495
2027
  self.c = c
1496
- self.lrmax = lrmax
1497
- self.lrmin = lrmin
1498
- super(JaxArmijoLineSearchPlanner, self).__init__(
1499
- *args,
1500
- optimizer=optimizer,
1501
- optimizer_kwargs=optimizer_kwargs,
1502
- **kwargs)
1503
-
1504
- def summarize_hyperparameters(self):
1505
- super(JaxArmijoLineSearchPlanner, self).summarize_hyperparameters()
2028
+ self.step_max = step_max
2029
+ self.step_min = step_min
2030
+ if 'clip_grad' in kwargs:
2031
+ raise_warning('clip_grad parameter conflicts with '
2032
+ 'line search planner and will be ignored.', 'red')
2033
+ del kwargs['clip_grad']
2034
+ super(JaxLineSearchPlanner, self).__init__(*args, **kwargs)
2035
+
2036
+ def summarize_hyperparameters(self) -> None:
2037
+ super(JaxLineSearchPlanner, self).summarize_hyperparameters()
1506
2038
  print(f'linesearch hyper-parameters:\n'
1507
- f' beta ={self.beta}\n'
2039
+ f' decay ={self.decay}\n'
1508
2040
  f' c ={self.c}\n'
1509
- f' lr_range=({self.lrmin}, {self.lrmax})\n')
2041
+ f' lr_range=({self.step_min}, {self.step_max})')
1510
2042
 
1511
2043
  def _jax_update(self, loss):
1512
2044
  optimizer = self.optimizer
1513
2045
  projection = self.plan.projection
1514
- beta, c, lrmax, lrmin = self.beta, self.c, self.lrmax, self.lrmin
1515
-
1516
- # continue line search if Armijo condition not satisfied and learning
1517
- # rate can be further reduced
1518
- def _jax_wrapped_line_search_armijo_check(val):
1519
- (_, old_f, _, old_norm_g2, _), (_, new_f, lr, _), _, _ = val
1520
- return jnp.logical_and(
1521
- new_f >= old_f - c * lr * old_norm_g2,
1522
- lr >= lrmin / beta)
1523
-
1524
- def _jax_wrapped_line_search_iteration(val):
1525
- old, new, best, aux = val
1526
- old_x, _, old_g, _, old_state = old
1527
- _, _, lr, iters = new
1528
- _, best_f, _, _ = best
1529
- key, hyperparams, *other = aux
1530
-
1531
- # anneal learning rate and apply a gradient step
1532
- new_lr = beta * lr
1533
- old_state.hyperparams['learning_rate'] = new_lr
1534
- updates, new_state = optimizer.update(old_g, old_state)
1535
- new_x = optax.apply_updates(old_x, updates)
1536
- new_x, _ = projection(new_x, hyperparams)
1537
-
1538
- # evaluate new loss and record best so far
1539
- new_f, _ = loss(key, new_x, hyperparams, *other)
1540
- new = (new_x, new_f, new_lr, iters + 1)
1541
- best = jax.lax.cond(
1542
- new_f < best_f,
1543
- lambda: (new_x, new_f, new_lr, new_state),
1544
- lambda: best
1545
- )
1546
- return old, new, best, aux
2046
+ decay, c, lrmax, lrmin = self.decay, self.c, self.step_max, self.step_min
2047
+
2048
+ # initialize the line search routine
2049
+ @jax.jit
2050
+ def _jax_wrapped_line_search_init(key, policy_params, hyperparams,
2051
+ subs, model_params):
2052
+ (f, log), grad = jax.value_and_grad(loss, argnums=1, has_aux=True)(
2053
+ key, policy_params, hyperparams, subs, model_params)
2054
+ gnorm2 = jax.tree_map(lambda x: jnp.sum(jnp.square(x)), grad)
2055
+ gnorm2 = jax.tree_util.tree_reduce(jnp.add, gnorm2)
2056
+ log['grad'] = grad
2057
+ return f, grad, gnorm2, log
1547
2058
 
2059
+ # compute the next trial solution
2060
+ @jax.jit
2061
+ def _jax_wrapped_line_search_trial(
2062
+ step, grad, key, params, hparams, subs, mparams, state):
2063
+ state.hyperparams['learning_rate'] = step
2064
+ updates, new_state = optimizer.update(grad, state)
2065
+ new_params = optax.apply_updates(params, updates)
2066
+ new_params, _ = projection(new_params, hparams)
2067
+ f_step, _ = loss(key, new_params, hparams, subs, mparams)
2068
+ return f_step, new_params, new_state
2069
+
2070
+ # main iteration of line search
1548
2071
  def _jax_wrapped_plan_update(key, policy_params, hyperparams,
1549
- subs, model_params, opt_state):
1550
-
1551
- # calculate initial loss value, gradient and squared norm
1552
- old_x = policy_params
1553
- loss_and_grad_fn = jax.value_and_grad(loss, argnums=1, has_aux=True)
1554
- (old_f, log), old_g = loss_and_grad_fn(
1555
- key, old_x, hyperparams, subs, model_params)
1556
- old_norm_g2 = jax.tree_map(lambda x: jnp.sum(jnp.square(x)), old_g)
1557
- old_norm_g2 = jax.tree_util.tree_reduce(jnp.add, old_norm_g2)
1558
- log['grad'] = old_g
2072
+ subs, model_params, opt_state, opt_aux):
1559
2073
 
1560
- # initialize learning rate to maximum
1561
- new_lr = lrmax / beta
1562
- old = (old_x, old_f, old_g, old_norm_g2, opt_state)
1563
- new = (old_x, old_f, new_lr, 0)
1564
- best = (old_x, jnp.inf, jnp.nan, opt_state)
1565
- aux = (key, hyperparams, subs, model_params)
2074
+ # initialize the line search
2075
+ f, grad, gnorm2, log = _jax_wrapped_line_search_init(
2076
+ key, policy_params, hyperparams, subs, model_params)
1566
2077
 
1567
- # do a single line search step with the initial learning rate
1568
- init_val = (old, new, best, aux)
1569
- init_val = _jax_wrapped_line_search_iteration(init_val)
2078
+ # continue to reduce the learning rate until the Armijo condition holds
2079
+ trials = 0
2080
+ step = lrmax / decay
2081
+ f_step = np.inf
2082
+ best_f, best_step, best_params, best_state = np.inf, None, None, None
2083
+ while (f_step > f - c * step * gnorm2 and step * decay >= lrmin) \
2084
+ or not trials:
2085
+ trials += 1
2086
+ step *= decay
2087
+ f_step, new_params, new_state = _jax_wrapped_line_search_trial(
2088
+ step, grad, key, policy_params, hyperparams, subs,
2089
+ model_params, opt_state)
2090
+ if f_step < best_f:
2091
+ best_f, best_step, best_params, best_state = \
2092
+ f_step, step, new_params, new_state
1570
2093
 
1571
- # continue to anneal the learning rate until Armijo condition holds
1572
- # or the learning rate becomes too small, then use the best parameter
1573
- _, (*_, iters), (best_params, _, best_lr, best_state), _ = \
1574
- jax.lax.while_loop(
1575
- cond_fun=_jax_wrapped_line_search_armijo_check,
1576
- body_fun=_jax_wrapped_line_search_iteration,
1577
- init_val=init_val
1578
- )
1579
- best_state.hyperparams['learning_rate'] = best_lr
1580
2094
  log['updates'] = None
1581
- log['line_search_iters'] = iters
1582
- log['learning_rate'] = best_lr
1583
- return best_params, True, best_state, log
2095
+ log['line_search_iters'] = trials
2096
+ log['learning_rate'] = best_step
2097
+ return best_params, True, best_state, best_step, best_f, log
1584
2098
 
1585
2099
  return _jax_wrapped_plan_update
1586
2100
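The update above is a plain backtracking line search: start at step_max, shrink the step by decay, and accept the first step satisfying the Armijo sufficient-decrease condition f(x - step*g) <= f(x) - c*step*||g||^2, or stop once the step would fall below step_min. The same idea on a toy quadratic, independent of the planner classes:

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum((x - 2.0) ** 2)   # toy objective with minimum at x = 2

x = jnp.zeros(3)
fx, g = jax.value_and_grad(f)(x)
gnorm2 = jnp.sum(g ** 2)

decay, c, step, step_min = 0.8, 0.1, 1.0, 1e-6
while f(x - step * g) > fx - c * step * gnorm2 and step * decay >= step_min:
    step *= decay                     # shrink until sufficient decrease holds

x_new = x - step * g                  # accepted gradient step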
 
1587
-
2101
+
2102
+ # ***********************************************************************
2103
+ # ALL VERSIONS OF RISK FUNCTIONS
2104
+ #
2105
+ # Based on the original paper "A Distributional Framework for Risk-Sensitive
2106
+ # End-to-End Planning in Continuous MDPs" by Patton et al., AAAI 2022.
2107
+ #
2108
+ # Original risk functions:
2109
+ # - entropic utility
2110
+ # - mean-variance approximation
2111
+ # - conditional value at risk with straight-through gradient trick
2112
+ #
2113
+ # ***********************************************************************
2114
+
2115
+
2116
+ @jax.jit
2117
+ def entropic_utility(returns: jnp.ndarray, beta: float) -> float:
2118
+ return (-1.0 / beta) * jax.scipy.special.logsumexp(
2119
+ -beta * returns, b=1.0 / returns.size)
2120
+
2121
+
2122
+ @jax.jit
2123
+ def mean_variance_utility(returns: jnp.ndarray, beta: float) -> float:
2124
+ return jnp.mean(returns) - 0.5 * beta * jnp.var(returns)
2125
+
2126
+
2127
+ @jax.jit
2128
+ def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
2129
+ alpha_mask = jax.lax.stop_gradient(
2130
+ returns <= jnp.percentile(returns, q=100 * alpha))
2131
+ return jnp.sum(returns * alpha_mask) / jnp.sum(alpha_mask)
2132
+
2133
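Each of the three utilities above collapses a batch of sampled returns into one risk-adjusted scalar for the planner to maximize instead of the plain mean: the entropic utility -(1/beta) log E[exp(-beta R)], the mean-variance approximation E[R] - (beta/2) Var[R], and CVaR at level alpha (the average of the worst alpha-fraction of returns, with a straight-through gradient through the selection mask). A quick check on a toy sample using the functions defined above:

import jax.numpy as jnp

returns = jnp.array([10.0, 8.0, -2.0, 7.0, 1.0])

print(entropic_utility(returns, beta=0.5))      # below the mean: penalizes the left tail
print(mean_variance_utility(returns, beta=0.5)) # mean minus scaled variance
print(cvar_utility(returns, alpha=0.4))         # mean of the worst 40% of returns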
+
2134
+ # ***********************************************************************
2135
+ # ALL VERSIONS OF CONTROLLERS
2136
+ #
2137
+ # - offline controller is the straight-line planner
2138
+ # - online controller is the replanning mode
2139
+ #
2140
+ # ***********************************************************************
2141
+
1588
2142
  class JaxOfflineController(BaseAgent):
1589
2143
  '''A container class for a Jax policy trained offline.'''
2144
+
1590
2145
  use_tensor_obs = True
1591
2146
 
1592
- def __init__(self, planner: JaxBackpropPlanner, key: random.PRNGKey,
1593
- eval_hyperparams: Dict[str, object]=None,
1594
- params: Dict[str, object]=None,
2147
+ def __init__(self, planner: JaxBackpropPlanner,
2148
+ key: Optional[random.PRNGKey]=None,
2149
+ eval_hyperparams: Optional[Dict[str, Any]]=None,
2150
+ params: Optional[Pytree]=None,
1595
2151
  train_on_reset: bool=False,
1596
2152
  **train_kwargs) -> None:
1597
2153
  '''Creates a new JAX offline control policy that is trained once, then
1598
2154
  deployed later.
1599
2155
 
1600
2156
  :param planner: underlying planning algorithm for optimizing actions
1601
- :param key: the RNG key to seed randomness
2157
+ :param key: the RNG key to seed randomness (derived from the clock if not
2158
+ provided)
1602
2159
  :param eval_hyperparams: policy hyperparameters to apply for evaluation
1603
2160
  or whenever sample_action is called
1604
2161
  :param params: use the specified policy parameters instead of calling
@@ -1608,6 +2165,8 @@ class JaxOfflineController(BaseAgent):
1608
2165
  for optimization
1609
2166
  '''
1610
2167
  self.planner = planner
2168
+ if key is None:
2169
+ key = random.PRNGKey(round(time.time() * 1000))
1611
2170
  self.key = key
1612
2171
  self.eval_hyperparams = eval_hyperparams
1613
2172
  self.train_on_reset = train_on_reset
@@ -1616,60 +2175,72 @@ class JaxOfflineController(BaseAgent):
1616
2175
 
1617
2176
  self.step = 0
1618
2177
  if not self.train_on_reset and not self.params_given:
1619
- params = self.planner.optimize(key=self.key, **self.train_kwargs)
2178
+ callback = self.planner.optimize(key=self.key, **self.train_kwargs)
2179
+ params = callback['best_params']
1620
2180
  self.params = params
1621
2181
 
1622
- def sample_action(self, state):
2182
+ def sample_action(self, state: Dict[str, Any]) -> Dict[str, Any]:
1623
2183
  self.key, subkey = random.split(self.key)
1624
2184
  actions = self.planner.get_action(
1625
2185
  subkey, self.params, self.step, state, self.eval_hyperparams)
1626
2186
  self.step += 1
1627
2187
  return actions
1628
2188
 
1629
- def reset(self):
2189
+ def reset(self) -> None:
1630
2190
  self.step = 0
1631
2191
  if self.train_on_reset and not self.params_given:
1632
- self.params = self.planner.optimize(key=self.key, **self.train_kwargs)
2192
+ callback = self.planner.optimize(key=self.key, **self.train_kwargs)
2193
+ self.params = callback['best_params']
1633
2194
 
1634
2195
 
1635
2196
  class JaxOnlineController(BaseAgent):
1636
2197
  '''A container class for a Jax controller continuously updated using state
1637
2198
  feedback.'''
2199
+
1638
2200
  use_tensor_obs = True
1639
2201
 
1640
- def __init__(self, planner: JaxBackpropPlanner, key: random.PRNGKey,
1641
- eval_hyperparams: Dict=None, warm_start: bool=True,
2202
+ def __init__(self, planner: JaxBackpropPlanner,
2203
+ key: Optional[random.PRNGKey]=None,
2204
+ eval_hyperparams: Optional[Dict[str, Any]]=None,
2205
+ warm_start: bool=True,
1642
2206
  **train_kwargs) -> None:
1643
2207
  '''Creates a new JAX control policy that is trained online in a closed-
1644
2208
  loop fashion.
1645
2209
 
1646
2210
  :param planner: underlying planning algorithm for optimizing actions
1647
- :param key: the RNG key to seed randomness
2211
+ :param key: the RNG key to seed randomness (derived from the clock if not
2212
+ provided)
1648
2213
  :param eval_hyperparams: policy hyperparameters to apply for evaluation
1649
2214
  or whenever sample_action is called
2215
+ :param warm_start: whether to use the final policy parameters from the
2216
+ previous decision epoch to warm-start the next decision epoch
1650
2217
  :param **train_kwargs: any keyword arguments to be passed to the planner
1651
2218
  for optimization
1652
2219
  '''
1653
2220
  self.planner = planner
2221
+ if key is None:
2222
+ key = random.PRNGKey(round(time.time() * 1000))
1654
2223
  self.key = key
1655
2224
  self.eval_hyperparams = eval_hyperparams
1656
2225
  self.warm_start = warm_start
1657
2226
  self.train_kwargs = train_kwargs
1658
2227
  self.reset()
1659
2228
 
1660
- def sample_action(self, state):
2229
+ def sample_action(self, state: Dict[str, Any]) -> Dict[str, Any]:
1661
2230
  planner = self.planner
1662
- params = planner.optimize(
2231
+ callback = planner.optimize(
1663
2232
  key=self.key,
1664
2233
  guess=self.guess,
1665
2234
  subs=state,
1666
2235
  **self.train_kwargs)
2236
+ params = callback['best_params']
1667
2237
  self.key, subkey = random.split(self.key)
1668
- actions = planner.get_action(subkey, params, 0, state, self.eval_hyperparams)
2238
+ actions = planner.get_action(
2239
+ subkey, params, 0, state, self.eval_hyperparams)
1669
2240
  if self.warm_start:
1670
2241
  self.guess = planner.plan.guess_next_epoch(params)
1671
2242
  return actions
1672
2243
 
1673
- def reset(self):
2244
+ def reset(self) -> None:
1674
2245
  self.guess = None
1675
2246
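The online controller re-runs the planner at every decision step with the current state substituted in, optionally warm-starting from the previous plan shifted by one epoch; it is the replanning counterpart of the offline controller above. A minimal sketch under the same assumptions as the previous example (existing planner and vectorized env, and the BaseAgent evaluate() interface):

# replan with a small time budget at every decision step
controller = JaxOnlineController(
    planner,
    warm_start=True,
    epochs=200,
    train_seconds=1.0,
    print_summary=False,
    print_progress=False)

stats = controller.evaluate(env, episodes=1, verbose=True, render=False)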