pyRDDLGym-jax 0.5-py3-none-any.whl → 1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. pyRDDLGym_jax/__init__.py +1 -1
  2. pyRDDLGym_jax/core/compiler.py +463 -592
  3. pyRDDLGym_jax/core/logic.py +784 -544
  4. pyRDDLGym_jax/core/planner.py +336 -472
  5. pyRDDLGym_jax/core/simulator.py +7 -5
  6. pyRDDLGym_jax/core/tuning.py +392 -567
  7. pyRDDLGym_jax/core/visualization.py +1463 -0
  8. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
  9. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +4 -5
  10. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
  11. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
  12. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
  13. pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
  14. pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
  15. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
  16. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +3 -3
  17. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
  18. pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
  19. pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
  20. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
  21. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +3 -3
  22. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
  23. pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
  24. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
  25. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +3 -3
  26. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
  27. pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
  28. pyRDDLGym_jax/examples/configs/default_replan.cfg +3 -3
  29. pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
  30. pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
  31. pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
  32. pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
  33. pyRDDLGym_jax/examples/run_plan.py +4 -1
  34. pyRDDLGym_jax/examples/run_tune.py +40 -27
  35. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.1.dist-info}/METADATA +167 -111
  36. pyRDDLGym_jax-1.1.dist-info/RECORD +45 -0
  37. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.1.dist-info}/WHEEL +1 -1
  38. pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
  39. pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
  40. pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
  41. pyRDDLGym_jax-0.5.dist-info/RECORD +0 -44
  42. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.1.dist-info}/LICENSE +0 -0
  43. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.1.dist-info}/top_level.txt +0 -0
@@ -31,22 +31,24 @@ from pyRDDLGym.core.policy import BaseAgent
  from pyRDDLGym_jax import __version__
  from pyRDDLGym_jax.core import logic
  from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler
- from pyRDDLGym_jax.core.logic import FuzzyLogic
+ from pyRDDLGym_jax.core.logic import Logic, FuzzyLogic

- # try to import matplotlib, if failed then skip plotting
+ # try to load the dash board
  try:
- import matplotlib.pyplot as plt
+ from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard
  except Exception:
- raise_warning('failed to import matplotlib: '
- 'plotting functionality will be disabled.', 'red')
+ raise_warning('Failed to load the dashboard visualization tool: '
+ 'please make sure you have installed the required packages.',
+ 'red')
  traceback.print_exc()
- plt = None
-
+ JaxPlannerDashboard = None
+
  Activation = Callable[[jnp.ndarray], jnp.ndarray]
  Bounds = Dict[str, Tuple[np.ndarray, np.ndarray]]
  Kwargs = Dict[str, Any]
  Pytree = Any

+
  # ***********************************************************************
  # CONFIG FILE MANAGEMENT
  #
@@ -63,9 +65,8 @@ def _parse_config_file(path: str):
  config = configparser.RawConfigParser()
  config.optionxform = str
  config.read(path)
- args = {k: literal_eval(v)
- for section in config.sections()
- for (k, v) in config.items(section)}
+ args = {section: {k: literal_eval(v) for (k, v) in config.items(section)}
+ for section in config.sections()}
  return config, args


@@ -73,9 +74,8 @@ def _parse_config_string(value: str):
  config = configparser.RawConfigParser()
  config.optionxform = str
  config.read_string(value)
- args = {k: literal_eval(v)
- for section in config.sections()
- for (k, v) in config.items(section)}
+ args = {section: {k: literal_eval(v) for (k, v) in config.items(section)}
+ for section in config.sections()}
  return config, args


@@ -88,28 +88,32 @@ def _getattr_any(packages, item):


  def _load_config(config, args):
- model_args = {k: args[k] for (k, _) in config.items('Model')}
- planner_args = {k: args[k] for (k, _) in config.items('Optimizer')}
- train_args = {k: args[k] for (k, _) in config.items('Training')}
+ model_args = {k: args['Model'][k] for (k, _) in config.items('Model')}
+ planner_args = {k: args['Optimizer'][k] for (k, _) in config.items('Optimizer')}
+ train_args = {k: args['Training'][k] for (k, _) in config.items('Training')}

  # read the model settings
  logic_name = model_args.get('logic', 'FuzzyLogic')
  logic_kwargs = model_args.get('logic_kwargs', {})
- tnorm_name = model_args.get('tnorm', 'ProductTNorm')
- tnorm_kwargs = model_args.get('tnorm_kwargs', {})
- comp_name = model_args.get('complement', 'StandardComplement')
- comp_kwargs = model_args.get('complement_kwargs', {})
- compare_name = model_args.get('comparison', 'SigmoidComparison')
- compare_kwargs = model_args.get('comparison_kwargs', {})
- sampling_name = model_args.get('sampling', 'GumbelSoftmax')
- sampling_kwargs = model_args.get('sampling_kwargs', {})
- rounding_name = model_args.get('rounding', 'SoftRounding')
- rounding_kwargs = model_args.get('rounding_kwargs', {})
- logic_kwargs['tnorm'] = getattr(logic, tnorm_name)(**tnorm_kwargs)
- logic_kwargs['complement'] = getattr(logic, comp_name)(**comp_kwargs)
- logic_kwargs['comparison'] = getattr(logic, compare_name)(**compare_kwargs)
- logic_kwargs['sampling'] = getattr(logic, sampling_name)(**sampling_kwargs)
- logic_kwargs['rounding'] = getattr(logic, rounding_name)(**rounding_kwargs)
+ if logic_name == 'FuzzyLogic':
+ tnorm_name = model_args.get('tnorm', 'ProductTNorm')
+ tnorm_kwargs = model_args.get('tnorm_kwargs', {})
+ comp_name = model_args.get('complement', 'StandardComplement')
+ comp_kwargs = model_args.get('complement_kwargs', {})
+ compare_name = model_args.get('comparison', 'SigmoidComparison')
+ compare_kwargs = model_args.get('comparison_kwargs', {})
+ sampling_name = model_args.get('sampling', 'GumbelSoftmax')
+ sampling_kwargs = model_args.get('sampling_kwargs', {})
+ rounding_name = model_args.get('rounding', 'SoftRounding')
+ rounding_kwargs = model_args.get('rounding_kwargs', {})
+ control_name = model_args.get('control', 'SoftControlFlow')
+ control_kwargs = model_args.get('control_kwargs', {})
+ logic_kwargs['tnorm'] = getattr(logic, tnorm_name)(**tnorm_kwargs)
+ logic_kwargs['complement'] = getattr(logic, comp_name)(**comp_kwargs)
+ logic_kwargs['comparison'] = getattr(logic, compare_name)(**compare_kwargs)
+ logic_kwargs['sampling'] = getattr(logic, sampling_name)(**sampling_kwargs)
+ logic_kwargs['rounding'] = getattr(logic, rounding_name)(**rounding_kwargs)
+ logic_kwargs['control'] = getattr(logic, control_name)(**control_kwargs)

  # read the policy settings
  plan_method = planner_args.pop('method')
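Taken together with the per-section parsing above, the 1.1 loader first indexes config values by section and then resolves the component names in [Model] against pyRDDLGym_jax.core.logic. A minimal sketch of that flow, assuming the package is installed; the section contents below are illustrative (the names shown are the loader's documented defaults), not copied from a shipped config:

from ast import literal_eval
import configparser

from pyRDDLGym_jax.core import logic

# illustrative [Model] section; key names follow the loader above
MODEL_ONLY = """
[Model]
logic='FuzzyLogic'
tnorm='ProductTNorm'
comparison='SigmoidComparison'
control='SoftControlFlow'
"""

config = configparser.RawConfigParser()
config.optionxform = str
config.read_string(MODEL_ONLY)

# 1.1 keeps one dict per section instead of one flat namespace
args = {section: {k: literal_eval(v) for (k, v) in config.items(section)}
        for section in config.sections()}
model_args = args['Model']

# name -> object lookup, as in _load_config (the default kwargs are empty dicts)
tnorm = getattr(logic, model_args['tnorm'])()            # logic.ProductTNorm()
comparison = getattr(logic, model_args['comparison'])()  # logic.SigmoidComparison()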
@@ -165,6 +169,13 @@ def _load_config(config, args):
  if planner_key is not None:
  train_args['key'] = random.PRNGKey(planner_key)

+ # dashboard
+ dashboard_key = train_args.get('dashboard', None)
+ if dashboard_key is not None and dashboard_key and JaxPlannerDashboard is not None:
+ train_args['dashboard'] = JaxPlannerDashboard()
+ elif dashboard_key is not None:
+ del train_args['dashboard']
+
  # optimize call stopping rule
  stopping_rule = train_args.get('stopping_rule', None)
  if stopping_rule is not None:
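A short sketch of what the new flag amounts to, assuming the JaxPlannerDashboard import above succeeded: a truthy dashboard entry in [Training] is replaced by a live dashboard object, while a falsy one is simply dropped before the arguments reach the planner. This is a simplified restatement of the branch above, not the loader itself; the 'epochs' value is a placeholder.

from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard

train_args = {'epochs': 1000, 'dashboard': True}   # as parsed from [Training]
if train_args.get('dashboard'):
    train_args['dashboard'] = JaxPlannerDashboard()
else:
    train_args.pop('dashboard', None)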
@@ -186,6 +197,7 @@ def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
  config, args = _parse_config_string(value)
  return _load_config(config, args)

+
  # ***********************************************************************
  # MODEL RELAXATIONS
  #
@@ -202,7 +214,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
  '''

  def __init__(self, *args,
- logic: FuzzyLogic=FuzzyLogic(),
+ logic: Logic=FuzzyLogic(),
  cpfs_without_grad: Optional[Set[str]]=None,
  **kwargs) -> None:
  '''Creates a new RDDL to Jax compiler, where operations that are not
@@ -237,57 +249,55 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
237
249
 
238
250
  # overwrite basic operations with fuzzy ones
239
251
  self.RELATIONAL_OPS = {
240
- '>=': logic.greater_equal(),
241
- '<=': logic.less_equal(),
242
- '<': logic.less(),
243
- '>': logic.greater(),
244
- '==': logic.equal(),
245
- '~=': logic.not_equal()
252
+ '>=': logic.greater_equal,
253
+ '<=': logic.less_equal,
254
+ '<': logic.less,
255
+ '>': logic.greater,
256
+ '==': logic.equal,
257
+ '~=': logic.not_equal
246
258
  }
247
- self.LOGICAL_NOT = logic.logical_not()
259
+ self.LOGICAL_NOT = logic.logical_not
248
260
  self.LOGICAL_OPS = {
249
- '^': logic.logical_and(),
250
- '&': logic.logical_and(),
251
- '|': logic.logical_or(),
252
- '~': logic.xor(),
253
- '=>': logic.implies(),
254
- '<=>': logic.equiv()
261
+ '^': logic.logical_and,
262
+ '&': logic.logical_and,
263
+ '|': logic.logical_or,
264
+ '~': logic.xor,
265
+ '=>': logic.implies,
266
+ '<=>': logic.equiv
255
267
  }
256
- self.AGGREGATION_OPS['forall'] = logic.forall()
257
- self.AGGREGATION_OPS['exists'] = logic.exists()
258
- self.AGGREGATION_OPS['argmin'] = logic.argmin()
259
- self.AGGREGATION_OPS['argmax'] = logic.argmax()
260
- self.KNOWN_UNARY['sgn'] = logic.sgn()
261
- self.KNOWN_UNARY['floor'] = logic.floor()
262
- self.KNOWN_UNARY['ceil'] = logic.ceil()
263
- self.KNOWN_UNARY['round'] = logic.round()
264
- self.KNOWN_UNARY['sqrt'] = logic.sqrt()
265
- self.KNOWN_BINARY['div'] = logic.div()
266
- self.KNOWN_BINARY['mod'] = logic.mod()
267
- self.KNOWN_BINARY['fmod'] = logic.mod()
268
- self.IF_HELPER = logic.control_if()
269
- self.SWITCH_HELPER = logic.control_switch()
270
- self.BERNOULLI_HELPER = logic.bernoulli()
271
- self.DISCRETE_HELPER = logic.discrete()
272
- self.POISSON_HELPER = logic.poisson()
273
- self.GEOMETRIC_HELPER = logic.geometric()
274
-
275
- def _jax_stop_grad(self, jax_expr):
276
-
268
+ self.AGGREGATION_OPS['forall'] = logic.forall
269
+ self.AGGREGATION_OPS['exists'] = logic.exists
270
+ self.AGGREGATION_OPS['argmin'] = logic.argmin
271
+ self.AGGREGATION_OPS['argmax'] = logic.argmax
272
+ self.KNOWN_UNARY['sgn'] = logic.sgn
273
+ self.KNOWN_UNARY['floor'] = logic.floor
274
+ self.KNOWN_UNARY['ceil'] = logic.ceil
275
+ self.KNOWN_UNARY['round'] = logic.round
276
+ self.KNOWN_UNARY['sqrt'] = logic.sqrt
277
+ self.KNOWN_BINARY['div'] = logic.div
278
+ self.KNOWN_BINARY['mod'] = logic.mod
279
+ self.KNOWN_BINARY['fmod'] = logic.mod
280
+ self.IF_HELPER = logic.control_if
281
+ self.SWITCH_HELPER = logic.control_switch
282
+ self.BERNOULLI_HELPER = logic.bernoulli
283
+ self.DISCRETE_HELPER = logic.discrete
284
+ self.POISSON_HELPER = logic.poisson
285
+ self.GEOMETRIC_HELPER = logic.geometric
286
+
287
+ def _jax_stop_grad(self, jax_expr):
277
288
  def _jax_wrapped_stop_grad(x, params, key):
278
- sample, key, error = jax_expr(x, params, key)
289
+ sample, key, error, params = jax_expr(x, params, key)
279
290
  sample = jax.lax.stop_gradient(sample)
280
- return sample, key, error
281
-
291
+ return sample, key, error, params
282
292
  return _jax_wrapped_stop_grad
283
293
 
284
- def _compile_cpfs(self, info):
294
+ def _compile_cpfs(self, init_params):
285
295
  cpfs_cast = set()
286
296
  jax_cpfs = {}
287
297
  for (_, cpfs) in self.levels.items():
288
298
  for cpf in cpfs:
289
299
  _, expr = self.rddl.cpfs[cpf]
290
- jax_cpfs[cpf] = self._jax(expr, info, dtype=self.REAL)
300
+ jax_cpfs[cpf] = self._jax(expr, init_params, dtype=self.REAL)
291
301
  if self.rddl.variable_ranges[cpf] != 'real':
292
302
  cpfs_cast.add(cpf)
293
303
  if cpf in self.cpfs_without_grad:
@@ -298,17 +308,16 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
298
308
  f'{cpfs_cast} be cast to float.')
299
309
  if self.cpfs_without_grad:
300
310
  raise_warning(f'User requested that gradients not flow '
301
- f'through CPFs {self.cpfs_without_grad}.')
311
+ f'through CPFs {self.cpfs_without_grad}.')
312
+
302
313
  return jax_cpfs
303
314
 
304
- def _jax_kron(self, expr, info):
305
- if self.logic.verbose:
306
- raise_warning('JAX gradient compiler ignores KronDelta '
307
- 'during compilation.')
315
+ def _jax_kron(self, expr, init_params):
308
316
  arg, = expr.args
309
- arg = self._jax(arg, info)
317
+ arg = self._jax(arg, init_params)
310
318
  return arg
311
319
 
320
+
312
321
  # ***********************************************************************
313
322
  # ALL VERSIONS OF JAX PLANS
314
323
  #
@@ -329,7 +338,7 @@ class JaxPlan:
  self.bounds = None

  def summarize_hyperparameters(self) -> None:
- pass
+ print(self.__str__())

  def compile(self, compiled: JaxRDDLCompilerWithGrad,
  _bounds: Bounds,
@@ -455,21 +464,21 @@ class JaxStraightLinePlan(JaxPlan):
455
464
  self._wrap_softmax = wrap_softmax
456
465
  self._use_new_projection = use_new_projection
457
466
  self._max_constraint_iter = max_constraint_iter
458
-
459
- def summarize_hyperparameters(self) -> None:
467
+
468
+ def __str__(self) -> str:
460
469
  bounds = '\n '.join(
461
470
  map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
462
- print(f'policy hyper-parameters:\n'
463
- f' initializer ={self._initializer_base}\n'
464
- f'constraint-sat strategy (simple):\n'
465
- f' parsed_action_bounds =\n {bounds}\n'
466
- f' wrap_sigmoid ={self._wrap_sigmoid}\n'
467
- f' wrap_sigmoid_min_prob={self._min_action_prob}\n'
468
- f' wrap_non_bool ={self._wrap_non_bool}\n'
469
- f'constraint-sat strategy (complex):\n'
470
- f' wrap_softmax ={self._wrap_softmax}\n'
471
- f' use_new_projection ={self._use_new_projection}\n'
472
- f' max_projection_iters ={self._max_constraint_iter}')
471
+ return (f'policy hyper-parameters:\n'
472
+ f' initializer ={self._initializer_base}\n'
473
+ f'constraint-sat strategy (simple):\n'
474
+ f' parsed_action_bounds =\n {bounds}\n'
475
+ f' wrap_sigmoid ={self._wrap_sigmoid}\n'
476
+ f' wrap_sigmoid_min_prob={self._min_action_prob}\n'
477
+ f' wrap_non_bool ={self._wrap_non_bool}\n'
478
+ f'constraint-sat strategy (complex):\n'
479
+ f' wrap_softmax ={self._wrap_softmax}\n'
480
+ f' use_new_projection ={self._use_new_projection}\n'
481
+ f' max_projection_iters ={self._max_constraint_iter}')
473
482
 
474
483
  def compile(self, compiled: JaxRDDLCompilerWithGrad,
475
484
  _bounds: Bounds,
@@ -820,21 +829,21 @@ class JaxDeepReactivePolicy(JaxPlan):
820
829
  normalizer_kwargs = {'create_offset': True, 'create_scale': True}
821
830
  self._normalizer_kwargs = normalizer_kwargs
822
831
  self._wrap_non_bool = wrap_non_bool
823
-
824
- def summarize_hyperparameters(self) -> None:
832
+
833
+ def __str__(self) -> str:
825
834
  bounds = '\n '.join(
826
835
  map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
827
- print(f'policy hyper-parameters:\n'
828
- f' topology ={self._topology}\n'
829
- f' activation_fn ={self._activations[0].__name__}\n'
830
- f' initializer ={type(self._initializer_base).__name__}\n'
831
- f' apply_input_norm ={self._normalize}\n'
832
- f' input_norm_layerwise={self._normalize_per_layer}\n'
833
- f' input_norm_args ={self._normalizer_kwargs}\n'
834
- f'constraint-sat strategy:\n'
835
- f' parsed_action_bounds=\n {bounds}\n'
836
- f' wrap_non_bool ={self._wrap_non_bool}')
837
-
836
+ return (f'policy hyper-parameters:\n'
837
+ f' topology ={self._topology}\n'
838
+ f' activation_fn ={self._activations[0].__name__}\n'
839
+ f' initializer ={type(self._initializer_base).__name__}\n'
840
+ f' apply_input_norm ={self._normalize}\n'
841
+ f' input_norm_layerwise={self._normalize_per_layer}\n'
842
+ f' input_norm_args ={self._normalizer_kwargs}\n'
843
+ f'constraint-sat strategy:\n'
844
+ f' parsed_action_bounds=\n {bounds}\n'
845
+ f' wrap_non_bool ={self._wrap_non_bool}')
846
+
838
847
  def compile(self, compiled: JaxRDDLCompilerWithGrad,
839
848
  _bounds: Bounds,
840
849
  horizon: int) -> None:
@@ -1057,6 +1066,7 @@ class JaxDeepReactivePolicy(JaxPlan):
1057
1066
  def guess_next_epoch(self, params: Pytree) -> Pytree:
1058
1067
  return params
1059
1068
 
1069
+
1060
1070
  # ***********************************************************************
1061
1071
  # ALL VERSIONS OF JAX PLANNER
1062
1072
  #
@@ -1084,113 +1094,6 @@ class RollingMean:
1084
1094
  return self._total / len(memory)
1085
1095
 
1086
1096
 
1087
- class JaxPlannerPlot:
1088
- '''Supports plotting and visualization of a JAX policy in real time.'''
1089
-
1090
- def __init__(self, rddl: RDDLPlanningModel, horizon: int,
1091
- show_violin: bool=True, show_action: bool=True) -> None:
1092
- '''Creates a new planner visualizer.
1093
-
1094
- :param rddl: the planning model to optimize
1095
- :param horizon: the lookahead or planning horizon
1096
- :param show_violin: whether to show the distribution of batch losses
1097
- :param show_action: whether to show heatmaps of the action fluents
1098
- '''
1099
- num_plots = 1
1100
- if show_violin:
1101
- num_plots += 1
1102
- if show_action:
1103
- num_plots += len(rddl.action_fluents)
1104
- self._fig, axes = plt.subplots(num_plots)
1105
- if num_plots == 1:
1106
- axes = [axes]
1107
-
1108
- # prepare the loss plot
1109
- self._loss_ax = axes[0]
1110
- self._loss_ax.autoscale(enable=True)
1111
- self._loss_ax.set_xlabel('training time')
1112
- self._loss_ax.set_ylabel('loss value')
1113
- self._loss_plot = self._loss_ax.plot(
1114
- [], [], linestyle=':', marker='o', markersize=2)[0]
1115
- self._loss_back = self._fig.canvas.copy_from_bbox(self._loss_ax.bbox)
1116
-
1117
- # prepare the violin plot
1118
- if show_violin:
1119
- self._hist_ax = axes[1]
1120
- else:
1121
- self._hist_ax = None
1122
-
1123
- # prepare the action plots
1124
- if show_action:
1125
- self._action_ax = {name: axes[idx + (2 if show_violin else 1)]
1126
- for (idx, name) in enumerate(rddl.action_fluents)}
1127
- self._action_plots = {}
1128
- for name in rddl.action_fluents:
1129
- ax = self._action_ax[name]
1130
- if rddl.variable_ranges[name] == 'bool':
1131
- vmin, vmax = 0.0, 1.0
1132
- else:
1133
- vmin, vmax = None, None
1134
- action_dim = 1
1135
- for dim in rddl.object_counts(rddl.variable_params[name]):
1136
- action_dim *= dim
1137
- action_plot = ax.pcolormesh(
1138
- np.zeros((action_dim, horizon)),
1139
- cmap='seismic', vmin=vmin, vmax=vmax)
1140
- ax.set_aspect('auto')
1141
- ax.set_xlabel('decision epoch')
1142
- ax.set_ylabel(name)
1143
- plt.colorbar(action_plot, ax=ax)
1144
- self._action_plots[name] = action_plot
1145
- self._action_back = {name: self._fig.canvas.copy_from_bbox(ax.bbox)
1146
- for (name, ax) in self._action_ax.items()}
1147
- else:
1148
- self._action_ax = None
1149
- self._action_plots = None
1150
- self._action_back = None
1151
-
1152
- plt.tight_layout()
1153
- plt.show(block=False)
1154
-
1155
- def redraw(self, xticks, losses, actions, returns) -> None:
1156
-
1157
- # draw the loss curve
1158
- self._fig.canvas.restore_region(self._loss_back)
1159
- self._loss_plot.set_xdata(xticks)
1160
- self._loss_plot.set_ydata(losses)
1161
- self._loss_ax.set_xlim([0, len(xticks)])
1162
- self._loss_ax.set_ylim([np.min(losses), np.max(losses)])
1163
- self._loss_ax.draw_artist(self._loss_plot)
1164
- self._fig.canvas.blit(self._loss_ax.bbox)
1165
-
1166
- # draw the violin plot
1167
- if self._hist_ax is not None:
1168
- self._hist_ax.clear()
1169
- self._hist_ax.set_xlabel('loss value')
1170
- self._hist_ax.set_ylabel('density')
1171
- self._hist_ax.violinplot(returns, vert=False, showmeans=True)
1172
-
1173
- # draw the actions
1174
- if self._action_ax is not None:
1175
- for (name, values) in actions.items():
1176
- values = np.mean(values, axis=0, dtype=float)
1177
- values = np.reshape(values, newshape=(values.shape[0], -1)).T
1178
- self._fig.canvas.restore_region(self._action_back[name])
1179
- self._action_plots[name].set_array(values)
1180
- self._action_ax[name].draw_artist(self._action_plots[name])
1181
- self._fig.canvas.blit(self._action_ax[name].bbox)
1182
- self._action_plots[name].set_clim([np.min(values), np.max(values)])
1183
-
1184
- self._fig.canvas.draw()
1185
- self._fig.canvas.flush_events()
1186
-
1187
- def close(self) -> None:
1188
- plt.close(self._fig)
1189
- del self._loss_ax, self._hist_ax, self._action_ax, \
1190
- self._loss_plot, self._action_plots, self._fig, \
1191
- self._loss_back, self._action_back
1192
-
1193
-
1194
1097
  class JaxPlannerStatus(Enum):
1195
1098
  '''Represents the status of a policy update from the JAX planner,
1196
1099
  including whether the update resulted in nan gradient,
@@ -1198,14 +1101,15 @@ class JaxPlannerStatus(Enum):
  can be used to monitor and act based on the planner's progress.'''

  NORMAL = 0
- NO_PROGRESS = 1
- PRECONDITION_POSSIBLY_UNSATISFIED = 2
- INVALID_GRADIENT = 3
- TIME_BUDGET_REACHED = 4
- ITER_BUDGET_REACHED = 5
+ STOPPING_RULE_REACHED = 1
+ NO_PROGRESS = 2
+ PRECONDITION_POSSIBLY_UNSATISFIED = 3
+ INVALID_GRADIENT = 4
+ TIME_BUDGET_REACHED = 5
+ ITER_BUDGET_REACHED = 6

- def is_failure(self) -> bool:
- return self.value >= 3
+ def is_terminal(self) -> bool:
+ return self.value == 1 or self.value >= 4


  class JaxPlannerStoppingRule:
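For reference, the 1.1 status codes and the new terminal-status test reproduced as a standalone snippet: a stopping-rule hit now has its own terminal code, while NO_PROGRESS and PRECONDITION_POSSIBLY_UNSATISFIED remain non-terminal. The class body below simply restates the diff; the asserts illustrate the semantics.

from enum import Enum

class JaxPlannerStatus(Enum):
    NORMAL = 0
    STOPPING_RULE_REACHED = 1
    NO_PROGRESS = 2
    PRECONDITION_POSSIBLY_UNSATISFIED = 3
    INVALID_GRADIENT = 4
    TIME_BUDGET_REACHED = 5
    ITER_BUDGET_REACHED = 6

    def is_terminal(self) -> bool:
        return self.value == 1 or self.value >= 4

assert JaxPlannerStatus.STOPPING_RULE_REACHED.is_terminal()
assert not JaxPlannerStatus.NO_PROGRESS.is_terminal()      # keeps optimizing
assert JaxPlannerStatus.TIME_BUDGET_REACHED.is_terminal()  # ends the optimize loop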
@@ -1255,15 +1159,16 @@ class JaxBackpropPlanner:
  optimizer: Callable[..., optax.GradientTransformation]=optax.rmsprop,
  optimizer_kwargs: Optional[Kwargs]=None,
  clip_grad: Optional[float]=None,
- noise_grad_eta: float=0.0,
- noise_grad_gamma: float=1.0,
- logic: FuzzyLogic=FuzzyLogic(),
+ line_search_kwargs: Optional[Kwargs]=None,
+ noise_kwargs: Optional[Kwargs]=None,
+ logic: Logic=FuzzyLogic(),
  use_symlog_reward: bool=False,
  utility: Union[Callable[[jnp.ndarray], float], str]='mean',
  utility_kwargs: Optional[Kwargs]=None,
  cpfs_without_grad: Optional[Set[str]]=None,
  compile_non_fluent_exact: bool=True,
- logger: Optional[Logger]=None) -> None:
+ logger: Optional[Logger]=None,
+ dashboard_viz: Optional[Any]=None) -> None:
  '''Creates a new gradient-based algorithm for optimizing action sequences
  (plan) in the given RDDL. Some operations will be converted to their
  differentiable counterparts; the specific operations can be customized
@@ -1283,9 +1188,10 @@ class JaxBackpropPlanner:
  :param optimizer_kwargs: a dictionary of parameters to pass to the SGD
  factory (e.g. which parameters are controllable externally)
  :param clip_grad: maximum magnitude of gradient updates
- :param noise_grad_eta: scale of the gradient noise variance
- :param noise_grad_gamma: decay rate of the gradient noise variance
- :param logic: a subclass of FuzzyLogic for mapping exact mathematical
+ :param line_search_kwargs: parameters to pass to optional line search
+ method to scale learning rate
+ :param noise_kwargs: parameters of optional gradient noise
+ :param logic: a subclass of Logic for mapping exact mathematical
  operations to their differentiable counterparts
  :param use_symlog_reward: whether to use the symlog transform on the
  reward as a form of normalization
@@ -1300,6 +1206,8 @@ class JaxBackpropPlanner:
  :param compile_non_fluent_exact: whether non-fluent expressions
  are always compiled using exact JAX expressions
  :param logger: to log information about compilation to file
+ :param dashboard_viz: optional visualizer object from the environment
+ to pass to the dashboard to visualize the policy
  '''
  self.rddl = rddl
  self.plan = plan
@@ -1314,31 +1222,34 @@ class JaxBackpropPlanner:
  action_bounds = {}
  self._action_bounds = action_bounds
  self.use64bit = use64bit
- self._optimizer_name = optimizer
+ self.optimizer_name = optimizer
  if optimizer_kwargs is None:
  optimizer_kwargs = {'learning_rate': 0.1}
- self._optimizer_kwargs = optimizer_kwargs
+ self.optimizer_kwargs = optimizer_kwargs
  self.clip_grad = clip_grad
- self.noise_grad_eta = noise_grad_eta
- self.noise_grad_gamma = noise_grad_gamma
+ self.line_search_kwargs = line_search_kwargs
+ self.noise_kwargs = noise_kwargs

  # set optimizer
  try:
  optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs)
  except Exception as _:
  raise_warning(
- 'Failed to inject hyperparameters into optax optimizer, '
+ f'Failed to inject hyperparameters into optax optimizer {optimizer}, '
  'rolling back to safer method: please note that modification of '
- 'optimizer hyperparameters will not work, and it is '
- 'recommended to update optax and related packages.', 'red')
- optimizer = optimizer(**optimizer_kwargs)
- if clip_grad is None:
- self.optimizer = optimizer
- else:
- self.optimizer = optax.chain(
- optax.clip(clip_grad),
- optimizer
- )
+ 'optimizer hyperparameters will not work.', 'red')
+ optimizer = optimizer(**optimizer_kwargs)
+
+ # apply optimizer chain of transformations
+ pipeline = []
+ if clip_grad is not None:
+ pipeline.append(optax.clip(clip_grad))
+ if noise_kwargs is not None:
+ pipeline.append(optax.add_noise(**noise_kwargs))
+ pipeline.append(optimizer)
+ if line_search_kwargs is not None:
+ pipeline.append(optax.scale_by_zoom_linesearch(**line_search_kwargs))
+ self.optimizer = optax.chain(*pipeline)

  # set utility
  if isinstance(utility, str):
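The constructor now assembles the update rule as a single optax chain: optional gradient clipping, optional Gaussian gradient noise, the hyperparameter-injected base optimizer, and (when line_search_kwargs is given) optax.scale_by_zoom_linesearch, whose update step additionally receives value, grad and value_fn, as shown later in _jax_update. A minimal standalone sketch of the first three stages on a toy quadratic loss; the kwarg values below are made up, only the key names mirror what the pipeline forwards to optax:

import jax.numpy as jnp
import optax

clip_grad = 1.0                                            # -> optax.clip
noise_kwargs = {'eta': 0.01, 'gamma': 0.55, 'seed': 42}    # -> optax.add_noise
optimizer_kwargs = {'learning_rate': 0.1}                  # -> injected into rmsprop

optimizer = optax.chain(
    optax.clip(clip_grad),
    optax.add_noise(**noise_kwargs),
    optax.inject_hyperparams(optax.rmsprop)(**optimizer_kwargs),
)

# one update step on loss(params) = sum(params ** 2)
params = jnp.array([1.0, -2.0, 3.0])
opt_state = optimizer.init(params)
grad = 2.0 * params
updates, opt_state = optimizer.update(grad, opt_state, params)
params = optax.apply_updates(params, updates)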
@@ -1371,11 +1282,12 @@ class JaxBackpropPlanner:
  self.cpfs_without_grad = cpfs_without_grad
  self.compile_non_fluent_exact = compile_non_fluent_exact
  self.logger = logger
+ self.dashboard_viz = dashboard_viz

  self._jax_compile_rddl()
  self._jax_compile_optimizer()

- def _summarize_system(self) -> None:
+ def summarize_system(self) -> str:
  try:
  jaxlib_version = jax._src.lib.version_str
  except Exception as _:
@@ -1394,36 +1306,55 @@ r"""
1394
1306
  \/_____/ \/_/\/_/ \/_/\/_/ \/_/ \/_____/ \/_/\/_/ \/_/ \/_/
1395
1307
  """
1396
1308
 
1397
- print('\n'
1398
- f'{LOGO}\n'
1399
- f'Version {__version__}\n'
1400
- f'Python {sys.version}\n'
1401
- f'jax {jax.version.__version__}, jaxlib {jaxlib_version}, '
1402
- f'optax {optax.__version__}, haiku {hk.__version__}, '
1403
- f'numpy {np.__version__}\n'
1404
- f'devices: {devices_short}\n')
1309
+ return ('\n'
1310
+ f'{LOGO}\n'
1311
+ f'Version {__version__}\n'
1312
+ f'Python {sys.version}\n'
1313
+ f'jax {jax.version.__version__}, jaxlib {jaxlib_version}, '
1314
+ f'optax {optax.__version__}, haiku {hk.__version__}, '
1315
+ f'numpy {np.__version__}\n'
1316
+ f'devices: {devices_short}\n')
1317
+
1318
+ def __str__(self) -> str:
1319
+ result = (f'objective hyper-parameters:\n'
1320
+ f' utility_fn ={self.utility.__name__}\n'
1321
+ f' utility args ={self.utility_kwargs}\n'
1322
+ f' use_symlog ={self.use_symlog_reward}\n'
1323
+ f' lookahead ={self.horizon}\n'
1324
+ f' user_action_bounds={self._action_bounds}\n'
1325
+ f' fuzzy logic type ={type(self.logic).__name__}\n'
1326
+ f' non_fluents exact ={self.compile_non_fluent_exact}\n'
1327
+ f' cpfs_no_gradient ={self.cpfs_without_grad}\n'
1328
+ f'optimizer hyper-parameters:\n'
1329
+ f' use_64_bit ={self.use64bit}\n'
1330
+ f' optimizer ={self.optimizer_name}\n'
1331
+ f' optimizer args ={self.optimizer_kwargs}\n'
1332
+ f' clip_gradient ={self.clip_grad}\n'
1333
+ f' line_search_kwargs={self.line_search_kwargs}\n'
1334
+ f' noise_kwargs ={self.noise_kwargs}\n'
1335
+ f' batch_size_train ={self.batch_size_train}\n'
1336
+ f' batch_size_test ={self.batch_size_test}')
1337
+ result += '\n' + str(self.plan)
1338
+ result += '\n' + str(self.logic)
1339
+
1340
+ # print model relaxation information
1341
+ if not self.compiled.model_params:
1342
+ return result
1343
+ result += '\n' + ('Some RDDL operations are non-differentiable '
1344
+ 'and will be approximated as follows:' + '\n')
1345
+ exprs_by_rddl_op, values_by_rddl_op = {}, {}
1346
+ for info in self.compiled.model_parameter_info().values():
1347
+ rddl_op = info['rddl_op']
1348
+ exprs_by_rddl_op.setdefault(rddl_op, []).append(info['id'])
1349
+ values_by_rddl_op.setdefault(rddl_op, []).append(info['init_value'])
1350
+ for rddl_op in sorted(exprs_by_rddl_op.keys()):
1351
+ result += (f' {rddl_op}:\n'
1352
+ f' addresses ={exprs_by_rddl_op[rddl_op]}\n'
1353
+ f' init_values={values_by_rddl_op[rddl_op]}\n')
1354
+ return result
1405
1355
 
1406
1356
  def summarize_hyperparameters(self) -> None:
1407
- print(f'objective hyper-parameters:\n'
1408
- f' utility_fn ={self.utility.__name__}\n'
1409
- f' utility args ={self.utility_kwargs}\n'
1410
- f' use_symlog ={self.use_symlog_reward}\n'
1411
- f' lookahead ={self.horizon}\n'
1412
- f' user_action_bounds={self._action_bounds}\n'
1413
- f' fuzzy logic type ={type(self.logic).__name__}\n'
1414
- f' nonfluents exact ={self.compile_non_fluent_exact}\n'
1415
- f' cpfs_no_gradient ={self.cpfs_without_grad}\n'
1416
- f'optimizer hyper-parameters:\n'
1417
- f' use_64_bit ={self.use64bit}\n'
1418
- f' optimizer ={self._optimizer_name.__name__}\n'
1419
- f' optimizer args ={self._optimizer_kwargs}\n'
1420
- f' clip_gradient ={self.clip_grad}\n'
1421
- f' noise_grad_eta ={self.noise_grad_eta}\n'
1422
- f' noise_grad_gamma ={self.noise_grad_gamma}\n'
1423
- f' batch_size_train ={self.batch_size_train}\n'
1424
- f' batch_size_test ={self.batch_size_test}')
1425
- self.plan.summarize_hyperparameters()
1426
- self.logic.summarize_hyperparameters()
1357
+ print(self.__str__())
1427
1358
 
1428
1359
  # ===========================================================================
1429
1360
  # COMPILATION SUBROUTINES
@@ -1439,14 +1370,16 @@ r"""
1439
1370
  logger=self.logger,
1440
1371
  use64bit=self.use64bit,
1441
1372
  cpfs_without_grad=self.cpfs_without_grad,
1442
- compile_non_fluent_exact=self.compile_non_fluent_exact)
1373
+ compile_non_fluent_exact=self.compile_non_fluent_exact
1374
+ )
1443
1375
  self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')
1444
1376
 
1445
1377
  # Jax compilation of the exact RDDL for testing
1446
1378
  self.test_compiled = JaxRDDLCompiler(
1447
1379
  rddl=rddl,
1448
1380
  logger=self.logger,
1449
- use64bit=self.use64bit)
1381
+ use64bit=self.use64bit
1382
+ )
1450
1383
  self.test_compiled.compile(log_jax_expr=True, heading='EXACT MODEL')
1451
1384
 
1452
1385
  def _jax_compile_optimizer(self):
@@ -1462,13 +1395,15 @@ r"""
1462
1395
  train_rollouts = self.compiled.compile_rollouts(
1463
1396
  policy=self.plan.train_policy,
1464
1397
  n_steps=self.horizon,
1465
- n_batch=self.batch_size_train)
1398
+ n_batch=self.batch_size_train
1399
+ )
1466
1400
  self.train_rollouts = train_rollouts
1467
1401
 
1468
1402
  test_rollouts = self.test_compiled.compile_rollouts(
1469
1403
  policy=self.plan.test_policy,
1470
1404
  n_steps=self.horizon,
1471
- n_batch=self.batch_size_test)
1405
+ n_batch=self.batch_size_test
1406
+ )
1472
1407
  self.test_rollouts = jax.jit(test_rollouts)
1473
1408
 
1474
1409
  # initialization
@@ -1503,14 +1438,16 @@ r"""
1503
1438
  _jax_wrapped_returns = self._jax_return(use_symlog)
1504
1439
 
1505
1440
  # the loss is the average cumulative reward across all roll-outs
1506
- def _jax_wrapped_plan_loss(key, policy_params, hyperparams,
1441
+ def _jax_wrapped_plan_loss(key, policy_params, policy_hyperparams,
1507
1442
  subs, model_params):
1508
- log = rollouts(key, policy_params, hyperparams, subs, model_params)
1443
+ log, model_params = rollouts(
1444
+ key, policy_params, policy_hyperparams, subs, model_params)
1509
1445
  rewards = log['reward']
1510
1446
  returns = _jax_wrapped_returns(rewards)
1511
1447
  utility = utility_fn(returns, **utility_kwargs)
1512
1448
  loss = -utility
1513
- return loss, log
1449
+ aux = (log, model_params)
1450
+ return loss, aux
1514
1451
 
1515
1452
  return _jax_wrapped_plan_loss
1516
1453
 
@@ -1518,8 +1455,9 @@ r"""
1518
1455
  init = self.plan.initializer
1519
1456
  optimizer = self.optimizer
1520
1457
 
1521
- def _jax_wrapped_init_policy(key, hyperparams, subs):
1522
- policy_params = init(key, hyperparams, subs)
1458
+ # initialize both the policy and its optimizer
1459
+ def _jax_wrapped_init_policy(key, policy_hyperparams, subs):
1460
+ policy_params = init(key, policy_hyperparams, subs)
1523
1461
  opt_state = optimizer.init(policy_params)
1524
1462
  return policy_params, opt_state, {}
1525
1463
 
@@ -1528,35 +1466,34 @@ r"""
1528
1466
  def _jax_update(self, loss):
1529
1467
  optimizer = self.optimizer
1530
1468
  projection = self.plan.projection
1531
-
1532
- # add Gaussian gradient noise per Neelakantan et al., 2016.
1533
- def _jax_wrapped_gaussian_param_noise(key, grads, sigma):
1534
- treedef = jax.tree_util.tree_structure(grads)
1535
- keys_flat = random.split(key, num=treedef.num_leaves)
1536
- keys_tree = jax.tree_util.tree_unflatten(treedef, keys_flat)
1537
- new_grads = jax.tree_map(
1538
- lambda g, k: g + sigma * random.normal(
1539
- key=k, shape=g.shape, dtype=g.dtype),
1540
- grads,
1541
- keys_tree
1542
- )
1543
- return new_grads
1469
+ use_ls = self.line_search_kwargs is not None
1544
1470
 
1545
1471
  # calculate the plan gradient w.r.t. return loss and update optimizer
1546
1472
  # also perform a projection step to satisfy constraints on actions
1547
- def _jax_wrapped_plan_update(key, policy_params, hyperparams,
1473
+ def _jax_wrapped_loss_swapped(policy_params, key, policy_hyperparams,
1474
+ subs, model_params):
1475
+ return loss(key, policy_params, policy_hyperparams, subs, model_params)[0]
1476
+
1477
+ def _jax_wrapped_plan_update(key, policy_params, policy_hyperparams,
1548
1478
  subs, model_params, opt_state, opt_aux):
1549
1479
  grad_fn = jax.value_and_grad(loss, argnums=1, has_aux=True)
1550
- (loss_val, log), grad = grad_fn(
1551
- key, policy_params, hyperparams, subs, model_params)
1552
- sigma = opt_aux.get('noise_sigma', 0.0)
1553
- grad = _jax_wrapped_gaussian_param_noise(key, grad, sigma)
1554
- updates, opt_state = optimizer.update(grad, opt_state)
1480
+ (loss_val, (log, model_params)), grad = grad_fn(
1481
+ key, policy_params, policy_hyperparams, subs, model_params)
1482
+ if use_ls:
1483
+ updates, opt_state = optimizer.update(
1484
+ grad, opt_state, params=policy_params,
1485
+ value=loss_val, grad=grad, value_fn=_jax_wrapped_loss_swapped,
1486
+ key=key, policy_hyperparams=policy_hyperparams, subs=subs,
1487
+ model_params=model_params)
1488
+ else:
1489
+ updates, opt_state = optimizer.update(
1490
+ grad, opt_state, params=policy_params)
1555
1491
  policy_params = optax.apply_updates(policy_params, updates)
1556
- policy_params, converged = projection(policy_params, hyperparams)
1492
+ policy_params, converged = projection(policy_params, policy_hyperparams)
1557
1493
  log['grad'] = grad
1558
1494
  log['updates'] = updates
1559
- return policy_params, converged, opt_state, opt_aux, loss_val, log
1495
+ return policy_params, converged, opt_state, opt_aux, \
1496
+ loss_val, log, model_params
1560
1497
 
1561
1498
  return jax.jit(_jax_wrapped_plan_update)
1562
1499
 
@@ -1583,7 +1520,6 @@ r"""
1583
1520
  for (state, next_state) in rddl.next_state.items():
1584
1521
  init_train[next_state] = init_train[state]
1585
1522
  init_test[next_state] = init_test[state]
1586
-
1587
1523
  return init_train, init_test
1588
1524
 
1589
1525
  def as_optimization_problem(
@@ -1637,38 +1573,40 @@ r"""
1637
1573
  loss_fn = self._jax_loss(self.train_rollouts)
1638
1574
 
1639
1575
  @jax.jit
1640
- def _loss_with_key(key, params_1d):
1576
+ def _loss_with_key(key, params_1d, model_params):
1641
1577
  policy_params = unravel_fn(params_1d)
1642
- loss_val, _ = loss_fn(key, policy_params, policy_hyperparams,
1643
- train_subs, model_params)
1644
- return loss_val
1578
+ loss_val, (_, model_params) = loss_fn(
1579
+ key, policy_params, policy_hyperparams, train_subs, model_params)
1580
+ return loss_val, model_params
1645
1581
 
1646
1582
  @jax.jit
1647
- def _grad_with_key(key, params_1d):
1583
+ def _grad_with_key(key, params_1d, model_params):
1648
1584
  policy_params = unravel_fn(params_1d)
1649
1585
  grad_fn = jax.grad(loss_fn, argnums=1, has_aux=True)
1650
- grad_val, _ = grad_fn(key, policy_params, policy_hyperparams,
1651
- train_subs, model_params)
1652
- grad_1d = jax.flatten_util.ravel_pytree(grad_val)[0]
1653
- return grad_1d
1586
+ grad_val, (_, model_params) = grad_fn(
1587
+ key, policy_params, policy_hyperparams, train_subs, model_params)
1588
+ grad_val = jax.flatten_util.ravel_pytree(grad_val)[0]
1589
+ return grad_val, model_params
1654
1590
 
1655
1591
  def _loss_function(params_1d):
1656
1592
  nonlocal key
1593
+ nonlocal model_params
1657
1594
  if loss_function_updates_key:
1658
1595
  key, subkey = random.split(key)
1659
1596
  else:
1660
1597
  subkey = key
1661
- loss_val = _loss_with_key(subkey, params_1d)
1598
+ loss_val, model_params = _loss_with_key(subkey, params_1d, model_params)
1662
1599
  loss_val = float(loss_val)
1663
1600
  return loss_val
1664
1601
 
1665
1602
  def _grad_function(params_1d):
1666
1603
  nonlocal key
1604
+ nonlocal model_params
1667
1605
  if grad_function_updates_key:
1668
1606
  key, subkey = random.split(key)
1669
1607
  else:
1670
1608
  subkey = key
1671
- grad = _grad_with_key(subkey, params_1d)
1609
+ grad, model_params = _grad_with_key(subkey, params_1d, model_params)
1672
1610
  grad = np.asarray(grad)
1673
1611
  return grad
1674
1612
 
@@ -1683,9 +1621,9 @@ r"""
1683
1621
 
1684
1622
  :param key: JAX PRNG key (derived from clock if not provided)
1685
1623
  :param epochs: the maximum number of steps of gradient descent
1686
- :param train_seconds: total time allocated for gradient descent
1687
- :param plot_step: frequency to plot the plan and save result to disk
1688
- :param plot_kwargs: additional arguments to pass to the plotter
1624
+ :param train_seconds: total time allocated for gradient descent
1625
+ :param dashboard: dashboard to display training results
1626
+ :param dashboard_id: experiment id for the dashboard
1689
1627
  :param model_params: optional model-parameters to override default
1690
1628
  :param policy_hyperparams: hyper-parameters for the policy/plan, such as
1691
1629
  weights for sigmoid wrapping boolean actions
@@ -1721,8 +1659,8 @@ r"""
1721
1659
  def optimize_generator(self, key: Optional[random.PRNGKey]=None,
1722
1660
  epochs: int=999999,
1723
1661
  train_seconds: float=120.,
1724
- plot_step: Optional[int]=None,
1725
- plot_kwargs: Optional[Kwargs]=None,
1662
+ dashboard: Optional[Any]=None,
1663
+ dashboard_id: Optional[str]=None,
1726
1664
  model_params: Optional[Dict[str, Any]]=None,
1727
1665
  policy_hyperparams: Optional[Dict[str, Any]]=None,
1728
1666
  subs: Optional[Dict[str, Any]]=None,
@@ -1738,9 +1676,9 @@ r"""
1738
1676
 
1739
1677
  :param key: JAX PRNG key (derived from clock if not provided)
1740
1678
  :param epochs: the maximum number of steps of gradient descent
1741
- :param train_seconds: total time allocated for gradient descent
1742
- :param plot_step: frequency to plot the plan and save result to disk
1743
- :param plot_kwargs: additional arguments to pass to the plotter
1679
+ :param train_seconds: total time allocated for gradient descent
1680
+ :param dashboard: dashboard to display training results
1681
+ :param dashboard_id: experiment id for the dashboard
1744
1682
  :param model_params: optional model-parameters to override default
1745
1683
  :param policy_hyperparams: hyper-parameters for the policy/plan, such as
1746
1684
  weights for sigmoid wrapping boolean actions
@@ -1760,9 +1698,14 @@ r"""
1760
1698
  start_time = time.time()
1761
1699
  elapsed_outside_loop = 0
1762
1700
 
1701
+ # ======================================================================
1702
+ # INITIALIZATION OF HYPER-PARAMETERS
1703
+ # ======================================================================
1704
+
1763
1705
  # if PRNG key is not provided
1764
1706
  if key is None:
1765
1707
  key = random.PRNGKey(round(time.time() * 1000))
1708
+ dash_key = key[1].item()
1766
1709
 
1767
1710
  # if policy_hyperparams is not provided
1768
1711
  if policy_hyperparams is None:
@@ -1789,7 +1732,7 @@ r"""

  # print summary of parameters:
  if print_summary:
- self._summarize_system()
+ print(self.summarize_system())
  self.summarize_hyperparameters()
  print(f'optimize() call hyper-parameters:\n'
  f' PRNG key ={key}\n'
@@ -1800,16 +1743,16 @@ r"""
1800
1743
  f' override_subs_dict ={subs is not None}\n'
1801
1744
  f' provide_param_guess={guess is not None}\n'
1802
1745
  f' test_rolling_window={test_rolling_window}\n'
1803
- f' plot_frequency ={plot_step}\n'
1804
- f' plot_kwargs ={plot_kwargs}\n'
1746
+ f' dashboard ={dashboard is not None}\n'
1747
+ f' dashboard_id ={dashboard_id}\n'
1805
1748
  f' print_summary ={print_summary}\n'
1806
1749
  f' print_progress ={print_progress}\n'
1807
1750
  f' stopping_rule ={stopping_rule}\n')
1808
- if self.compiled.relaxations:
1809
- print('Some RDDL operations are non-differentiable, '
1810
- 'they will be approximated as follows:')
1811
- print(self.compiled.summarize_model_relaxations())
1812
-
1751
+
1752
+ # ======================================================================
1753
+ # INITIALIZATION OF STATE AND POLICY
1754
+ # ======================================================================
1755
+
1813
1756
  # compute a batched version of the initial values
1814
1757
  if subs is None:
1815
1758
  subs = self.test_compiled.init_values
@@ -1841,6 +1784,10 @@ r"""
1841
1784
  policy_params = guess
1842
1785
  opt_state = self.optimizer.init(policy_params)
1843
1786
  opt_aux = {}
1787
+
1788
+ # ======================================================================
1789
+ # INITIALIZATION OF RUNNING STATISTICS
1790
+ # ======================================================================
1844
1791
 
1845
1792
  # initialize running statistics
1846
1793
  best_params, best_loss, best_grad = policy_params, jnp.inf, jnp.inf
@@ -1854,35 +1801,39 @@ r"""
1854
1801
  if stopping_rule is not None:
1855
1802
  stopping_rule.reset()
1856
1803
 
1857
- # initialize plot area
1858
- if plot_step is None or plot_step <= 0 or plt is None:
1859
- plot = None
1860
- else:
1861
- if plot_kwargs is None:
1862
- plot_kwargs = {}
1863
- plot = JaxPlannerPlot(self.rddl, self.horizon, **plot_kwargs)
1864
- xticks, loss_values = [], []
1804
+ # initialize dash board
1805
+ if dashboard is not None:
1806
+ dashboard_id = dashboard.register_experiment(
1807
+ dashboard_id, dashboard.get_planner_info(self),
1808
+ key=dash_key, viz=self.dashboard_viz)
1809
+
1810
+ # ======================================================================
1811
+ # MAIN TRAINING LOOP BEGINS
1812
+ # ======================================================================
1865
1813
 
1866
- # training loop
1867
1814
  iters = range(epochs)
1868
1815
  if print_progress:
1869
1816
  iters = tqdm(iters, total=100, position=tqdm_position)
1870
1817
  position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
1871
1818
 
1872
1819
  for it in iters:
1873
- status = JaxPlannerStatus.NORMAL
1874
1820
 
1875
- # gradient noise schedule
1876
- noise_var = self.noise_grad_eta / (1. + it) ** self.noise_grad_gamma
1877
- noise_sigma = np.sqrt(noise_var)
1878
- opt_aux['noise_sigma'] = noise_sigma
1821
+ # ==================================================================
1822
+ # NEXT GRADIENT DESCENT STEP
1823
+ # ==================================================================
1824
+
1825
+ status = JaxPlannerStatus.NORMAL
1879
1826
 
1880
1827
  # update the parameters of the plan
1881
1828
  key, subkey = random.split(key)
1882
- policy_params, converged, opt_state, opt_aux, \
1883
- train_loss, train_log = \
1829
+ (policy_params, converged, opt_state, opt_aux,
1830
+ train_loss, train_log, model_params) = \
1884
1831
  self.update(subkey, policy_params, policy_hyperparams,
1885
1832
  train_subs, model_params, opt_state, opt_aux)
1833
+
1834
+ # ==================================================================
1835
+ # STATUS CHECKS AND LOGGING
1836
+ # ==================================================================
1886
1837
 
1887
1838
  # no progress
1888
1839
  grad_norm_zero, _ = jax.tree_util.tree_flatten(
@@ -1901,12 +1852,11 @@ r"""
1901
1852
  # numerical error
1902
1853
  if not np.isfinite(train_loss):
1903
1854
  raise_warning(
1904
- f'Aborting JAX planner due to invalid train loss {train_loss}.',
1905
- 'red')
1855
+ f'JAX planner aborted due to invalid loss {train_loss}.', 'red')
1906
1856
  status = JaxPlannerStatus.INVALID_GRADIENT
1907
1857
 
1908
1858
  # evaluate test losses and record best plan so far
1909
- test_loss, log = self.test_loss(
1859
+ test_loss, (log, model_params_test) = self.test_loss(
1910
1860
  subkey, policy_params, policy_hyperparams,
1911
1861
  test_subs, model_params_test)
1912
1862
  test_loss = rolling_test_loss.update(test_loss)
@@ -1915,32 +1865,15 @@ r"""
1915
1865
  policy_params, test_loss, train_log['grad']
1916
1866
  last_iter_improve = it
1917
1867
 
1918
- # save the plan figure
1919
- if plot is not None and it % plot_step == 0:
1920
- xticks.append(it // plot_step)
1921
- loss_values.append(test_loss.item())
1922
- action_values = {name: values
1923
- for (name, values) in log['fluents'].items()
1924
- if name in self.rddl.action_fluents}
1925
- returns = -np.sum(np.asarray(log['reward']), axis=1)
1926
- plot.redraw(xticks, loss_values, action_values, returns)
1927
-
1928
- # if the progress bar is used
1929
- elapsed = time.time() - start_time - elapsed_outside_loop
1930
- if print_progress:
1931
- iters.n = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
1932
- iters.set_description(
1933
- f'{position_str} {it:6} it / {-train_loss:14.6f} train / '
1934
- f'{-test_loss:14.6f} test / {-best_loss:14.6f} best / '
1935
- f'{status.value} status')
1936
-
1937
1868
  # reached computation budget
1869
+ elapsed = time.time() - start_time - elapsed_outside_loop
1938
1870
  if elapsed >= train_seconds:
1939
1871
  status = JaxPlannerStatus.TIME_BUDGET_REACHED
1940
1872
  if it >= epochs - 1:
1941
1873
  status = JaxPlannerStatus.ITER_BUDGET_REACHED
1942
1874
 
1943
- # return a callback
1875
+ # build a callback
1876
+ progress_percent = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
1944
1877
  callback = {
1945
1878
  'status': status,
1946
1879
  'iteration': it,
@@ -1952,29 +1885,48 @@ r"""
1952
1885
  'last_iteration_improved': last_iter_improve,
1953
1886
  'grad': train_log['grad'],
1954
1887
  'best_grad': best_grad,
1955
- 'noise_sigma': noise_sigma,
1956
1888
  'updates': train_log['updates'],
1957
1889
  'elapsed_time': elapsed,
1958
1890
  'key': key,
1891
+ 'model_params': model_params,
1892
+ 'progress': progress_percent,
1893
+ 'train_log': train_log,
1959
1894
  **log
1960
1895
  }
1896
+
1897
+ # stopping condition reached
1898
+ if stopping_rule is not None and stopping_rule.monitor(callback):
1899
+ callback['status'] = status = JaxPlannerStatus.STOPPING_RULE_REACHED
1900
+
1901
+ # if the progress bar is used
1902
+ if print_progress:
1903
+ iters.n = progress_percent
1904
+ iters.set_description(
1905
+ f'{position_str} {it:6} it / {-train_loss:14.6f} train / '
1906
+ f'{-test_loss:14.6f} test / {-best_loss:14.6f} best / '
1907
+ f'{status.value} status'
1908
+ )
1909
+
1910
+ # dash-board
1911
+ if dashboard is not None:
1912
+ dashboard.update_experiment(dashboard_id, callback)
1913
+
1914
+ # yield the callback
1961
1915
  start_time_outside = time.time()
1962
1916
  yield callback
1963
1917
  elapsed_outside_loop += (time.time() - start_time_outside)
1964
1918
 
1965
1919
  # abortion check
1966
- if status.is_failure():
1967
- break
1968
-
1969
- # stopping condition reached
1970
- if stopping_rule is not None and stopping_rule.monitor(callback):
1971
- break
1972
-
1920
+ if status.is_terminal():
1921
+ break
1922
+
1923
+ # ======================================================================
1924
+ # POST-PROCESSING AND CLEANUP
1925
+ # ======================================================================
1926
+
1973
1927
  # release resources
1974
1928
  if print_progress:
1975
1929
  iters.close()
1976
- if plot is not None:
1977
- plot.close()
1978
1930
 
1979
1931
  # validate the test return
1980
1932
  if log:
@@ -2098,101 +2050,6 @@ r"""
2098
2050
  return actions
2099
2051
 
2100
2052
 
2101
- class JaxLineSearchPlanner(JaxBackpropPlanner):
2102
- '''A class for optimizing an action sequence in the given RDDL MDP using
2103
- linear search gradient descent, with the Armijo condition.'''
2104
-
2105
- def __init__(self, *args,
2106
- decay: float=0.8,
2107
- c: float=0.1,
2108
- step_max: float=1.0,
2109
- step_min: float=1e-6,
2110
- **kwargs) -> None:
2111
- '''Creates a new gradient-based algorithm for optimizing action sequences
2112
- (plan) in the given RDDL using line search. All arguments are the
2113
- same as in the parent class, except:
2114
-
2115
- :param decay: reduction factor of learning rate per line search iteration
2116
- :param c: positive coefficient in Armijo condition, should be in (0, 1)
2117
- :param step_max: initial learning rate for line search
2118
- :param step_min: minimum possible learning rate (line search halts)
2119
- '''
2120
- self.decay = decay
2121
- self.c = c
2122
- self.step_max = step_max
2123
- self.step_min = step_min
2124
- if 'clip_grad' in kwargs:
2125
- raise_warning('clip_grad parameter conflicts with '
2126
- 'line search planner and will be ignored.', 'red')
2127
- del kwargs['clip_grad']
2128
- super(JaxLineSearchPlanner, self).__init__(*args, **kwargs)
2129
-
2130
- def summarize_hyperparameters(self) -> None:
2131
- super(JaxLineSearchPlanner, self).summarize_hyperparameters()
2132
- print(f'linesearch hyper-parameters:\n'
2133
- f' decay ={self.decay}\n'
2134
- f' c ={self.c}\n'
2135
- f' lr_range=({self.step_min}, {self.step_max})')
2136
-
2137
- def _jax_update(self, loss):
2138
- optimizer = self.optimizer
2139
- projection = self.plan.projection
2140
- decay, c, lrmax, lrmin = self.decay, self.c, self.step_max, self.step_min
2141
-
2142
- # initialize the line search routine
2143
- @jax.jit
2144
- def _jax_wrapped_line_search_init(key, policy_params, hyperparams,
2145
- subs, model_params):
2146
- (f, log), grad = jax.value_and_grad(loss, argnums=1, has_aux=True)(
2147
- key, policy_params, hyperparams, subs, model_params)
2148
- gnorm2 = jax.tree_map(lambda x: jnp.sum(jnp.square(x)), grad)
2149
- gnorm2 = jax.tree_util.tree_reduce(jnp.add, gnorm2)
2150
- log['grad'] = grad
2151
- return f, grad, gnorm2, log
2152
-
2153
- # compute the next trial solution
2154
- @jax.jit
2155
- def _jax_wrapped_line_search_trial(
2156
- step, grad, key, params, hparams, subs, mparams, state):
2157
- state.hyperparams['learning_rate'] = step
2158
- updates, new_state = optimizer.update(grad, state)
2159
- new_params = optax.apply_updates(params, updates)
2160
- new_params, _ = projection(new_params, hparams)
2161
- f_step, _ = loss(key, new_params, hparams, subs, mparams)
2162
- return f_step, new_params, new_state
2163
-
2164
- # main iteration of line search
2165
- def _jax_wrapped_plan_update(key, policy_params, hyperparams,
2166
- subs, model_params, opt_state, opt_aux):
2167
-
2168
- # initialize the line search
2169
- f, grad, gnorm2, log = _jax_wrapped_line_search_init(
2170
- key, policy_params, hyperparams, subs, model_params)
2171
-
2172
- # continue to reduce the learning rate until the Armijo condition holds
2173
- trials = 0
2174
- step = lrmax / decay
2175
- f_step = np.inf
2176
- best_f, best_step, best_params, best_state = np.inf, None, None, None
2177
- while (f_step > f - c * step * gnorm2 and step * decay >= lrmin) \
2178
- or not trials:
2179
- trials += 1
2180
- step *= decay
2181
- f_step, new_params, new_state = _jax_wrapped_line_search_trial(
2182
- step, grad, key, policy_params, hyperparams, subs,
2183
- model_params, opt_state)
2184
- if f_step < best_f:
2185
- best_f, best_step, best_params, best_state = \
2186
- f_step, step, new_params, new_state
2187
-
2188
- log['updates'] = None
2189
- log['line_search_iters'] = trials
2190
- log['learning_rate'] = best_step
2191
- opt_aux['best_step'] = best_step
2192
- return best_params, True, best_state, opt_aux, best_f, log
2193
-
2194
- return _jax_wrapped_plan_update
2195
-
2196
2053
  # ***********************************************************************
2197
2054
  # ALL VERSIONS OF RISK FUNCTIONS
2198
2055
  #
@@ -2224,6 +2081,7 @@ def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
2224
2081
  returns <= jnp.percentile(returns, q=100 * alpha))
2225
2082
  return jnp.sum(returns * alpha_mask) / jnp.sum(alpha_mask)
2226
2083
 
2084
+
2227
2085
  # ***********************************************************************
2228
2086
  # ALL VERSIONS OF CONTROLLERS
2229
2087
  #
@@ -2268,8 +2126,10 @@ class JaxOfflineController(BaseAgent):
  self.params_given = params is not None

  self.step = 0
+ self.callback = None
  if not self.train_on_reset and not self.params_given:
  callback = self.planner.optimize(key=self.key, **self.train_kwargs)
+ self.callback = callback
  params = callback['best_params']
  self.params = params

@@ -2284,6 +2144,7 @@ class JaxOfflineController(BaseAgent):
  self.step = 0
  if self.train_on_reset and not self.params_given:
  callback = self.planner.optimize(key=self.key, **self.train_kwargs)
+ self.callback = callback
  self.params = callback['best_params']


@@ -2326,7 +2187,9 @@ class JaxOnlineController(BaseAgent):
  key=self.key,
  guess=self.guess,
  subs=state,
- **self.train_kwargs)
+ **self.train_kwargs
+ )
+ self.callback = callback
  params = callback['best_params']
  self.key, subkey = random.split(self.key)
  actions = planner.get_action(
@@ -2337,4 +2200,5 @@ class JaxOnlineController(BaseAgent):

  def reset(self) -> None:
  self.guess = None
+ self.callback = None
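Both controllers now keep the last optimization callback, so the final training statistics remain inspectable after an episode. A hedged usage sketch under stated assumptions: the domain, instance and keyword arguments below are placeholders, and the environment/planner construction follows the package documentation rather than this diff; only the controller.callback attribute and the 'status' / 'best_params' keys are taken from the changes above.

import pyRDDLGym
from pyRDDLGym_jax.core.planner import (
    JaxBackpropPlanner, JaxOfflineController, JaxStraightLinePlan)

# placeholder domain/instance; any RDDL problem works here
env = pyRDDLGym.make('Wildfire_MDP_ippc2014', '1', vectorized=True)
planner = JaxBackpropPlanner(rddl=env.model, plan=JaxStraightLinePlan())
controller = JaxOfflineController(planner, train_seconds=30)   # trains once up front
controller.evaluate(env, episodes=1)

# new in 1.1: the callback returned by optimize() is stored on the controller
print(controller.callback['status'])
best_params = controller.callback['best_params']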