pyRDDLGym-jax 0.5-py3-none-any.whl → 1.0-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (43)
  1. pyRDDLGym_jax/__init__.py +1 -1
  2. pyRDDLGym_jax/core/compiler.py +463 -592
  3. pyRDDLGym_jax/core/logic.py +784 -544
  4. pyRDDLGym_jax/core/planner.py +329 -463
  5. pyRDDLGym_jax/core/simulator.py +7 -5
  6. pyRDDLGym_jax/core/tuning.py +379 -568
  7. pyRDDLGym_jax/core/visualization.py +1463 -0
  8. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
  9. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +4 -5
  10. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
  11. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
  12. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
  13. pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
  14. pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
  15. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
  16. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +3 -3
  17. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
  18. pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
  19. pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
  20. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
  21. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +3 -3
  22. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
  23. pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
  24. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
  25. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +3 -3
  26. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
  27. pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
  28. pyRDDLGym_jax/examples/configs/default_replan.cfg +3 -3
  29. pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
  30. pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
  31. pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
  32. pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
  33. pyRDDLGym_jax/examples/run_plan.py +4 -1
  34. pyRDDLGym_jax/examples/run_tune.py +40 -27
  35. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/METADATA +161 -104
  36. pyRDDLGym_jax-1.0.dist-info/RECORD +45 -0
  37. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/WHEEL +1 -1
  38. pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
  39. pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
  40. pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
  41. pyRDDLGym_jax-0.5.dist-info/RECORD +0 -44
  42. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/LICENSE +0 -0
  43. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/top_level.txt +0 -0
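
The headline changes in this release: a training dashboard (pyRDDLGym_jax/core/visualization.py) replaces the old matplotlib JaxPlannerPlot, relaxations are described by a Logic base class with a new pluggable soft control-flow component, and gradient noise plus line search are folded into a single optax optimizer pipeline, which makes the separate JaxLineSearchPlanner unnecessary. The sketch below is not taken from the diff; it shows the new dashboard being passed to optimize(), with the domain name, hyper-parameter values and experiment id as placeholders.

    import pyRDDLGym
    from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxStraightLinePlan
    from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard

    # assumed example domain/instance; any pyRDDLGym environment works the same way
    env = pyRDDLGym.make('Cartpole_Continuous_gym', '0', vectorized=True)

    planner = JaxBackpropPlanner(rddl=env.model, plan=JaxStraightLinePlan())
    callback = planner.optimize(
        epochs=1000,
        train_seconds=60.0,
        dashboard=JaxPlannerDashboard(),   # replaces the removed plot_step/plot_kwargs
        dashboard_id='cartpole-slp'        # hypothetical experiment label
    )
    best_params = callback['best_params']
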
@@ -31,22 +31,24 @@ from pyRDDLGym.core.policy import BaseAgent
31
31
  from pyRDDLGym_jax import __version__
32
32
  from pyRDDLGym_jax.core import logic
33
33
  from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler
34
- from pyRDDLGym_jax.core.logic import FuzzyLogic
34
+ from pyRDDLGym_jax.core.logic import Logic, FuzzyLogic
35
35
 
36
- # try to import matplotlib, if failed then skip plotting
36
+ # try to load the dash board
37
37
  try:
38
- import matplotlib.pyplot as plt
38
+ from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard
39
39
  except Exception:
40
- raise_warning('failed to import matplotlib: '
41
- 'plotting functionality will be disabled.', 'red')
40
+ raise_warning('Failed to load the dashboard visualization tool: '
41
+ 'please make sure you have installed the required packages.',
42
+ 'red')
42
43
  traceback.print_exc()
43
- plt = None
44
-
44
+ JaxPlannerDashboard = None
45
+
45
46
  Activation = Callable[[jnp.ndarray], jnp.ndarray]
46
47
  Bounds = Dict[str, Tuple[np.ndarray, np.ndarray]]
47
48
  Kwargs = Dict[str, Any]
48
49
  Pytree = Any
49
50
 
51
+
50
52
  # ***********************************************************************
51
53
  # CONFIG FILE MANAGEMENT
52
54
  #
@@ -95,21 +97,25 @@ def _load_config(config, args):
95
97
  # read the model settings
96
98
  logic_name = model_args.get('logic', 'FuzzyLogic')
97
99
  logic_kwargs = model_args.get('logic_kwargs', {})
98
- tnorm_name = model_args.get('tnorm', 'ProductTNorm')
99
- tnorm_kwargs = model_args.get('tnorm_kwargs', {})
100
- comp_name = model_args.get('complement', 'StandardComplement')
101
- comp_kwargs = model_args.get('complement_kwargs', {})
102
- compare_name = model_args.get('comparison', 'SigmoidComparison')
103
- compare_kwargs = model_args.get('comparison_kwargs', {})
104
- sampling_name = model_args.get('sampling', 'GumbelSoftmax')
105
- sampling_kwargs = model_args.get('sampling_kwargs', {})
106
- rounding_name = model_args.get('rounding', 'SoftRounding')
107
- rounding_kwargs = model_args.get('rounding_kwargs', {})
108
- logic_kwargs['tnorm'] = getattr(logic, tnorm_name)(**tnorm_kwargs)
109
- logic_kwargs['complement'] = getattr(logic, comp_name)(**comp_kwargs)
110
- logic_kwargs['comparison'] = getattr(logic, compare_name)(**compare_kwargs)
111
- logic_kwargs['sampling'] = getattr(logic, sampling_name)(**sampling_kwargs)
112
- logic_kwargs['rounding'] = getattr(logic, rounding_name)(**rounding_kwargs)
100
+ if logic_name == 'FuzzyLogic':
101
+ tnorm_name = model_args.get('tnorm', 'ProductTNorm')
102
+ tnorm_kwargs = model_args.get('tnorm_kwargs', {})
103
+ comp_name = model_args.get('complement', 'StandardComplement')
104
+ comp_kwargs = model_args.get('complement_kwargs', {})
105
+ compare_name = model_args.get('comparison', 'SigmoidComparison')
106
+ compare_kwargs = model_args.get('comparison_kwargs', {})
107
+ sampling_name = model_args.get('sampling', 'GumbelSoftmax')
108
+ sampling_kwargs = model_args.get('sampling_kwargs', {})
109
+ rounding_name = model_args.get('rounding', 'SoftRounding')
110
+ rounding_kwargs = model_args.get('rounding_kwargs', {})
111
+ control_name = model_args.get('control', 'SoftControlFlow')
112
+ control_kwargs = model_args.get('control_kwargs', {})
113
+ logic_kwargs['tnorm'] = getattr(logic, tnorm_name)(**tnorm_kwargs)
114
+ logic_kwargs['complement'] = getattr(logic, comp_name)(**comp_kwargs)
115
+ logic_kwargs['comparison'] = getattr(logic, compare_name)(**compare_kwargs)
116
+ logic_kwargs['sampling'] = getattr(logic, sampling_name)(**sampling_kwargs)
117
+ logic_kwargs['rounding'] = getattr(logic, rounding_name)(**rounding_kwargs)
118
+ logic_kwargs['control'] = getattr(logic, control_name)(**control_kwargs)
113
119
 
114
120
  # read the policy settings
115
121
  plan_method = planner_args.pop('method')
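
For reference, the new FuzzyLogic branch above simply instantiates the component classes named in the config and collects them into logic_kwargs. A rough equivalent in code, assuming logic_kwargs is ultimately forwarded to the FuzzyLogic constructor as in earlier releases:

    from pyRDDLGym_jax.core import logic

    # the defaults read above, including the new 'control' component in 1.0
    relaxation = logic.FuzzyLogic(
        tnorm=logic.ProductTNorm(),
        complement=logic.StandardComplement(),
        comparison=logic.SigmoidComparison(),
        sampling=logic.GumbelSoftmax(),
        rounding=logic.SoftRounding(),
        control=logic.SoftControlFlow()
    )
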
@@ -165,6 +171,13 @@ def _load_config(config, args):
165
171
  if planner_key is not None:
166
172
  train_args['key'] = random.PRNGKey(planner_key)
167
173
 
174
+ # dashboard
175
+ dashboard_key = train_args.get('dashboard', None)
176
+ if dashboard_key is not None and dashboard_key and JaxPlannerDashboard is not None:
177
+ train_args['dashboard'] = JaxPlannerDashboard()
178
+ elif dashboard_key is not None:
179
+ del train_args['dashboard']
180
+
168
181
  # optimize call stopping rule
169
182
  stopping_rule = train_args.get('stopping_rule', None)
170
183
  if stopping_rule is not None:
@@ -186,6 +199,7 @@ def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
186
199
  config, args = _parse_config_string(value)
187
200
  return _load_config(config, args)
188
201
 
202
+
189
203
  # ***********************************************************************
190
204
  # MODEL RELAXATIONS
191
205
  #
@@ -202,7 +216,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
202
216
  '''
203
217
 
204
218
  def __init__(self, *args,
205
- logic: FuzzyLogic=FuzzyLogic(),
219
+ logic: Logic=FuzzyLogic(),
206
220
  cpfs_without_grad: Optional[Set[str]]=None,
207
221
  **kwargs) -> None:
208
222
  '''Creates a new RDDL to Jax compiler, where operations that are not
@@ -237,57 +251,55 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
237
251
 
238
252
  # overwrite basic operations with fuzzy ones
239
253
  self.RELATIONAL_OPS = {
240
- '>=': logic.greater_equal(),
241
- '<=': logic.less_equal(),
242
- '<': logic.less(),
243
- '>': logic.greater(),
244
- '==': logic.equal(),
245
- '~=': logic.not_equal()
254
+ '>=': logic.greater_equal,
255
+ '<=': logic.less_equal,
256
+ '<': logic.less,
257
+ '>': logic.greater,
258
+ '==': logic.equal,
259
+ '~=': logic.not_equal
246
260
  }
247
- self.LOGICAL_NOT = logic.logical_not()
261
+ self.LOGICAL_NOT = logic.logical_not
248
262
  self.LOGICAL_OPS = {
249
- '^': logic.logical_and(),
250
- '&': logic.logical_and(),
251
- '|': logic.logical_or(),
252
- '~': logic.xor(),
253
- '=>': logic.implies(),
254
- '<=>': logic.equiv()
263
+ '^': logic.logical_and,
264
+ '&': logic.logical_and,
265
+ '|': logic.logical_or,
266
+ '~': logic.xor,
267
+ '=>': logic.implies,
268
+ '<=>': logic.equiv
255
269
  }
256
- self.AGGREGATION_OPS['forall'] = logic.forall()
257
- self.AGGREGATION_OPS['exists'] = logic.exists()
258
- self.AGGREGATION_OPS['argmin'] = logic.argmin()
259
- self.AGGREGATION_OPS['argmax'] = logic.argmax()
260
- self.KNOWN_UNARY['sgn'] = logic.sgn()
261
- self.KNOWN_UNARY['floor'] = logic.floor()
262
- self.KNOWN_UNARY['ceil'] = logic.ceil()
263
- self.KNOWN_UNARY['round'] = logic.round()
264
- self.KNOWN_UNARY['sqrt'] = logic.sqrt()
265
- self.KNOWN_BINARY['div'] = logic.div()
266
- self.KNOWN_BINARY['mod'] = logic.mod()
267
- self.KNOWN_BINARY['fmod'] = logic.mod()
268
- self.IF_HELPER = logic.control_if()
269
- self.SWITCH_HELPER = logic.control_switch()
270
- self.BERNOULLI_HELPER = logic.bernoulli()
271
- self.DISCRETE_HELPER = logic.discrete()
272
- self.POISSON_HELPER = logic.poisson()
273
- self.GEOMETRIC_HELPER = logic.geometric()
274
-
275
- def _jax_stop_grad(self, jax_expr):
276
-
270
+ self.AGGREGATION_OPS['forall'] = logic.forall
271
+ self.AGGREGATION_OPS['exists'] = logic.exists
272
+ self.AGGREGATION_OPS['argmin'] = logic.argmin
273
+ self.AGGREGATION_OPS['argmax'] = logic.argmax
274
+ self.KNOWN_UNARY['sgn'] = logic.sgn
275
+ self.KNOWN_UNARY['floor'] = logic.floor
276
+ self.KNOWN_UNARY['ceil'] = logic.ceil
277
+ self.KNOWN_UNARY['round'] = logic.round
278
+ self.KNOWN_UNARY['sqrt'] = logic.sqrt
279
+ self.KNOWN_BINARY['div'] = logic.div
280
+ self.KNOWN_BINARY['mod'] = logic.mod
281
+ self.KNOWN_BINARY['fmod'] = logic.mod
282
+ self.IF_HELPER = logic.control_if
283
+ self.SWITCH_HELPER = logic.control_switch
284
+ self.BERNOULLI_HELPER = logic.bernoulli
285
+ self.DISCRETE_HELPER = logic.discrete
286
+ self.POISSON_HELPER = logic.poisson
287
+ self.GEOMETRIC_HELPER = logic.geometric
288
+
289
+ def _jax_stop_grad(self, jax_expr):
277
290
  def _jax_wrapped_stop_grad(x, params, key):
278
- sample, key, error = jax_expr(x, params, key)
291
+ sample, key, error, params = jax_expr(x, params, key)
279
292
  sample = jax.lax.stop_gradient(sample)
280
- return sample, key, error
281
-
293
+ return sample, key, error, params
282
294
  return _jax_wrapped_stop_grad
283
295
 
284
- def _compile_cpfs(self, info):
296
+ def _compile_cpfs(self, init_params):
285
297
  cpfs_cast = set()
286
298
  jax_cpfs = {}
287
299
  for (_, cpfs) in self.levels.items():
288
300
  for cpf in cpfs:
289
301
  _, expr = self.rddl.cpfs[cpf]
290
- jax_cpfs[cpf] = self._jax(expr, info, dtype=self.REAL)
302
+ jax_cpfs[cpf] = self._jax(expr, init_params, dtype=self.REAL)
291
303
  if self.rddl.variable_ranges[cpf] != 'real':
292
304
  cpfs_cast.add(cpf)
293
305
  if cpf in self.cpfs_without_grad:
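
Note the new calling convention visible in _jax_stop_grad above: a compiled sub-expression now returns the (possibly updated) model parameters along with the sample, key and error code. Any custom wrapper written against the 0.5 signature has to thread the extra value through; a hypothetical sketch:

    import jax.numpy as jnp

    # hypothetical wrapper following the 1.0 convention
    # (x, params, key) -> (sample, key, error, params)
    def make_clipped(jax_expr, low, high):
        def _jax_wrapped_clip(x, params, key):
            sample, key, error, params = jax_expr(x, params, key)
            return jnp.clip(sample, low, high), key, error, params
        return _jax_wrapped_clip
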
@@ -298,17 +310,16 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
298
310
  f'{cpfs_cast} be cast to float.')
299
311
  if self.cpfs_without_grad:
300
312
  raise_warning(f'User requested that gradients not flow '
301
- f'through CPFs {self.cpfs_without_grad}.')
313
+ f'through CPFs {self.cpfs_without_grad}.')
314
+
302
315
  return jax_cpfs
303
316
 
304
- def _jax_kron(self, expr, info):
305
- if self.logic.verbose:
306
- raise_warning('JAX gradient compiler ignores KronDelta '
307
- 'during compilation.')
317
+ def _jax_kron(self, expr, init_params):
308
318
  arg, = expr.args
309
- arg = self._jax(arg, info)
319
+ arg = self._jax(arg, init_params)
310
320
  return arg
311
321
 
322
+
312
323
  # ***********************************************************************
313
324
  # ALL VERSIONS OF JAX PLANS
314
325
  #
@@ -329,7 +340,7 @@ class JaxPlan:
329
340
  self.bounds = None
330
341
 
331
342
  def summarize_hyperparameters(self) -> None:
332
- pass
343
+ print(self.__str__())
333
344
 
334
345
  def compile(self, compiled: JaxRDDLCompilerWithGrad,
335
346
  _bounds: Bounds,
@@ -455,21 +466,21 @@ class JaxStraightLinePlan(JaxPlan):
455
466
  self._wrap_softmax = wrap_softmax
456
467
  self._use_new_projection = use_new_projection
457
468
  self._max_constraint_iter = max_constraint_iter
458
-
459
- def summarize_hyperparameters(self) -> None:
469
+
470
+ def __str__(self) -> str:
460
471
  bounds = '\n '.join(
461
472
  map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
462
- print(f'policy hyper-parameters:\n'
463
- f' initializer ={self._initializer_base}\n'
464
- f'constraint-sat strategy (simple):\n'
465
- f' parsed_action_bounds =\n {bounds}\n'
466
- f' wrap_sigmoid ={self._wrap_sigmoid}\n'
467
- f' wrap_sigmoid_min_prob={self._min_action_prob}\n'
468
- f' wrap_non_bool ={self._wrap_non_bool}\n'
469
- f'constraint-sat strategy (complex):\n'
470
- f' wrap_softmax ={self._wrap_softmax}\n'
471
- f' use_new_projection ={self._use_new_projection}\n'
472
- f' max_projection_iters ={self._max_constraint_iter}')
473
+ return (f'policy hyper-parameters:\n'
474
+ f' initializer ={self._initializer_base}\n'
475
+ f'constraint-sat strategy (simple):\n'
476
+ f' parsed_action_bounds =\n {bounds}\n'
477
+ f' wrap_sigmoid ={self._wrap_sigmoid}\n'
478
+ f' wrap_sigmoid_min_prob={self._min_action_prob}\n'
479
+ f' wrap_non_bool ={self._wrap_non_bool}\n'
480
+ f'constraint-sat strategy (complex):\n'
481
+ f' wrap_softmax ={self._wrap_softmax}\n'
482
+ f' use_new_projection ={self._use_new_projection}\n'
483
+ f' max_projection_iters ={self._max_constraint_iter}')
473
484
 
474
485
  def compile(self, compiled: JaxRDDLCompilerWithGrad,
475
486
  _bounds: Bounds,
@@ -820,21 +831,21 @@ class JaxDeepReactivePolicy(JaxPlan):
820
831
  normalizer_kwargs = {'create_offset': True, 'create_scale': True}
821
832
  self._normalizer_kwargs = normalizer_kwargs
822
833
  self._wrap_non_bool = wrap_non_bool
823
-
824
- def summarize_hyperparameters(self) -> None:
834
+
835
+ def __str__(self) -> str:
825
836
  bounds = '\n '.join(
826
837
  map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
827
- print(f'policy hyper-parameters:\n'
828
- f' topology ={self._topology}\n'
829
- f' activation_fn ={self._activations[0].__name__}\n'
830
- f' initializer ={type(self._initializer_base).__name__}\n'
831
- f' apply_input_norm ={self._normalize}\n'
832
- f' input_norm_layerwise={self._normalize_per_layer}\n'
833
- f' input_norm_args ={self._normalizer_kwargs}\n'
834
- f'constraint-sat strategy:\n'
835
- f' parsed_action_bounds=\n {bounds}\n'
836
- f' wrap_non_bool ={self._wrap_non_bool}')
837
-
838
+ return (f'policy hyper-parameters:\n'
839
+ f' topology ={self._topology}\n'
840
+ f' activation_fn ={self._activations[0].__name__}\n'
841
+ f' initializer ={type(self._initializer_base).__name__}\n'
842
+ f' apply_input_norm ={self._normalize}\n'
843
+ f' input_norm_layerwise={self._normalize_per_layer}\n'
844
+ f' input_norm_args ={self._normalizer_kwargs}\n'
845
+ f'constraint-sat strategy:\n'
846
+ f' parsed_action_bounds=\n {bounds}\n'
847
+ f' wrap_non_bool ={self._wrap_non_bool}')
848
+
838
849
  def compile(self, compiled: JaxRDDLCompilerWithGrad,
839
850
  _bounds: Bounds,
840
851
  horizon: int) -> None:
@@ -1057,6 +1068,7 @@ class JaxDeepReactivePolicy(JaxPlan):
1057
1068
  def guess_next_epoch(self, params: Pytree) -> Pytree:
1058
1069
  return params
1059
1070
 
1071
+
1060
1072
  # ***********************************************************************
1061
1073
  # ALL VERSIONS OF JAX PLANNER
1062
1074
  #
@@ -1084,113 +1096,6 @@ class RollingMean:
1084
1096
  return self._total / len(memory)
1085
1097
 
1086
1098
 
1087
- class JaxPlannerPlot:
1088
- '''Supports plotting and visualization of a JAX policy in real time.'''
1089
-
1090
- def __init__(self, rddl: RDDLPlanningModel, horizon: int,
1091
- show_violin: bool=True, show_action: bool=True) -> None:
1092
- '''Creates a new planner visualizer.
1093
-
1094
- :param rddl: the planning model to optimize
1095
- :param horizon: the lookahead or planning horizon
1096
- :param show_violin: whether to show the distribution of batch losses
1097
- :param show_action: whether to show heatmaps of the action fluents
1098
- '''
1099
- num_plots = 1
1100
- if show_violin:
1101
- num_plots += 1
1102
- if show_action:
1103
- num_plots += len(rddl.action_fluents)
1104
- self._fig, axes = plt.subplots(num_plots)
1105
- if num_plots == 1:
1106
- axes = [axes]
1107
-
1108
- # prepare the loss plot
1109
- self._loss_ax = axes[0]
1110
- self._loss_ax.autoscale(enable=True)
1111
- self._loss_ax.set_xlabel('training time')
1112
- self._loss_ax.set_ylabel('loss value')
1113
- self._loss_plot = self._loss_ax.plot(
1114
- [], [], linestyle=':', marker='o', markersize=2)[0]
1115
- self._loss_back = self._fig.canvas.copy_from_bbox(self._loss_ax.bbox)
1116
-
1117
- # prepare the violin plot
1118
- if show_violin:
1119
- self._hist_ax = axes[1]
1120
- else:
1121
- self._hist_ax = None
1122
-
1123
- # prepare the action plots
1124
- if show_action:
1125
- self._action_ax = {name: axes[idx + (2 if show_violin else 1)]
1126
- for (idx, name) in enumerate(rddl.action_fluents)}
1127
- self._action_plots = {}
1128
- for name in rddl.action_fluents:
1129
- ax = self._action_ax[name]
1130
- if rddl.variable_ranges[name] == 'bool':
1131
- vmin, vmax = 0.0, 1.0
1132
- else:
1133
- vmin, vmax = None, None
1134
- action_dim = 1
1135
- for dim in rddl.object_counts(rddl.variable_params[name]):
1136
- action_dim *= dim
1137
- action_plot = ax.pcolormesh(
1138
- np.zeros((action_dim, horizon)),
1139
- cmap='seismic', vmin=vmin, vmax=vmax)
1140
- ax.set_aspect('auto')
1141
- ax.set_xlabel('decision epoch')
1142
- ax.set_ylabel(name)
1143
- plt.colorbar(action_plot, ax=ax)
1144
- self._action_plots[name] = action_plot
1145
- self._action_back = {name: self._fig.canvas.copy_from_bbox(ax.bbox)
1146
- for (name, ax) in self._action_ax.items()}
1147
- else:
1148
- self._action_ax = None
1149
- self._action_plots = None
1150
- self._action_back = None
1151
-
1152
- plt.tight_layout()
1153
- plt.show(block=False)
1154
-
1155
- def redraw(self, xticks, losses, actions, returns) -> None:
1156
-
1157
- # draw the loss curve
1158
- self._fig.canvas.restore_region(self._loss_back)
1159
- self._loss_plot.set_xdata(xticks)
1160
- self._loss_plot.set_ydata(losses)
1161
- self._loss_ax.set_xlim([0, len(xticks)])
1162
- self._loss_ax.set_ylim([np.min(losses), np.max(losses)])
1163
- self._loss_ax.draw_artist(self._loss_plot)
1164
- self._fig.canvas.blit(self._loss_ax.bbox)
1165
-
1166
- # draw the violin plot
1167
- if self._hist_ax is not None:
1168
- self._hist_ax.clear()
1169
- self._hist_ax.set_xlabel('loss value')
1170
- self._hist_ax.set_ylabel('density')
1171
- self._hist_ax.violinplot(returns, vert=False, showmeans=True)
1172
-
1173
- # draw the actions
1174
- if self._action_ax is not None:
1175
- for (name, values) in actions.items():
1176
- values = np.mean(values, axis=0, dtype=float)
1177
- values = np.reshape(values, newshape=(values.shape[0], -1)).T
1178
- self._fig.canvas.restore_region(self._action_back[name])
1179
- self._action_plots[name].set_array(values)
1180
- self._action_ax[name].draw_artist(self._action_plots[name])
1181
- self._fig.canvas.blit(self._action_ax[name].bbox)
1182
- self._action_plots[name].set_clim([np.min(values), np.max(values)])
1183
-
1184
- self._fig.canvas.draw()
1185
- self._fig.canvas.flush_events()
1186
-
1187
- def close(self) -> None:
1188
- plt.close(self._fig)
1189
- del self._loss_ax, self._hist_ax, self._action_ax, \
1190
- self._loss_plot, self._action_plots, self._fig, \
1191
- self._loss_back, self._action_back
1192
-
1193
-
1194
1099
  class JaxPlannerStatus(Enum):
1195
1100
  '''Represents the status of a policy update from the JAX planner,
1196
1101
  including whether the update resulted in nan gradient,
@@ -1198,14 +1103,15 @@ class JaxPlannerStatus(Enum):
1198
1103
  can be used to monitor and act based on the planner's progress.'''
1199
1104
 
1200
1105
  NORMAL = 0
1201
- NO_PROGRESS = 1
1202
- PRECONDITION_POSSIBLY_UNSATISFIED = 2
1203
- INVALID_GRADIENT = 3
1204
- TIME_BUDGET_REACHED = 4
1205
- ITER_BUDGET_REACHED = 5
1106
+ STOPPING_RULE_REACHED = 1
1107
+ NO_PROGRESS = 2
1108
+ PRECONDITION_POSSIBLY_UNSATISFIED = 3
1109
+ INVALID_GRADIENT = 4
1110
+ TIME_BUDGET_REACHED = 5
1111
+ ITER_BUDGET_REACHED = 6
1206
1112
 
1207
- def is_failure(self) -> bool:
1208
- return self.value >= 3
1113
+ def is_terminal(self) -> bool:
1114
+ return self.value == 1 or self.value >= 4
1209
1115
 
1210
1116
 
1211
1117
  class JaxPlannerStoppingRule:
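
Since the enum values shifted and is_failure() became is_terminal() (which now also covers the new STOPPING_RULE_REACHED case), callers should compare against the enum members rather than raw integers. A short sketch, assuming a JaxBackpropPlanner instance named planner:

    from pyRDDLGym_jax.core.planner import JaxPlannerStatus

    last_callback = None
    for callback in planner.optimize_generator(epochs=2000, train_seconds=120.0):
        last_callback = callback
        if callback['status'] == JaxPlannerStatus.INVALID_GRADIENT:
            print('nan or infinite gradient encountered')
    print('final status:', last_callback['status'],
          'terminal:', last_callback['status'].is_terminal())
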
@@ -1255,15 +1161,16 @@ class JaxBackpropPlanner:
1255
1161
  optimizer: Callable[..., optax.GradientTransformation]=optax.rmsprop,
1256
1162
  optimizer_kwargs: Optional[Kwargs]=None,
1257
1163
  clip_grad: Optional[float]=None,
1258
- noise_grad_eta: float=0.0,
1259
- noise_grad_gamma: float=1.0,
1260
- logic: FuzzyLogic=FuzzyLogic(),
1164
+ line_search_kwargs: Optional[Kwargs]=None,
1165
+ noise_kwargs: Optional[Kwargs]=None,
1166
+ logic: Logic=FuzzyLogic(),
1261
1167
  use_symlog_reward: bool=False,
1262
1168
  utility: Union[Callable[[jnp.ndarray], float], str]='mean',
1263
1169
  utility_kwargs: Optional[Kwargs]=None,
1264
1170
  cpfs_without_grad: Optional[Set[str]]=None,
1265
1171
  compile_non_fluent_exact: bool=True,
1266
- logger: Optional[Logger]=None) -> None:
1172
+ logger: Optional[Logger]=None,
1173
+ dashboard_viz: Optional[Any]=None) -> None:
1267
1174
  '''Creates a new gradient-based algorithm for optimizing action sequences
1268
1175
  (plan) in the given RDDL. Some operations will be converted to their
1269
1176
  differentiable counterparts; the specific operations can be customized
@@ -1283,9 +1190,10 @@ class JaxBackpropPlanner:
1283
1190
  :param optimizer_kwargs: a dictionary of parameters to pass to the SGD
1284
1191
  factory (e.g. which parameters are controllable externally)
1285
1192
  :param clip_grad: maximum magnitude of gradient updates
1286
- :param noise_grad_eta: scale of the gradient noise variance
1287
- :param noise_grad_gamma: decay rate of the gradient noise variance
1288
- :param logic: a subclass of FuzzyLogic for mapping exact mathematical
1193
+ :param line_search_kwargs: parameters to pass to optional line search
1194
+ method to scale learning rate
1195
+ :param noise_kwargs: parameters of optional gradient noise
1196
+ :param logic: a subclass of Logic for mapping exact mathematical
1289
1197
  operations to their differentiable counterparts
1290
1198
  :param use_symlog_reward: whether to use the symlog transform on the
1291
1199
  reward as a form of normalization
@@ -1300,6 +1208,8 @@ class JaxBackpropPlanner:
1300
1208
  :param compile_non_fluent_exact: whether non-fluent expressions
1301
1209
  are always compiled using exact JAX expressions
1302
1210
  :param logger: to log information about compilation to file
1211
+ :param dashboard_viz: optional visualizer object from the environment
1212
+ to pass to the dashboard to visualize the policy
1303
1213
  '''
1304
1214
  self.rddl = rddl
1305
1215
  self.plan = plan
@@ -1314,31 +1224,34 @@ class JaxBackpropPlanner:
1314
1224
  action_bounds = {}
1315
1225
  self._action_bounds = action_bounds
1316
1226
  self.use64bit = use64bit
1317
- self._optimizer_name = optimizer
1227
+ self.optimizer_name = optimizer
1318
1228
  if optimizer_kwargs is None:
1319
1229
  optimizer_kwargs = {'learning_rate': 0.1}
1320
- self._optimizer_kwargs = optimizer_kwargs
1230
+ self.optimizer_kwargs = optimizer_kwargs
1321
1231
  self.clip_grad = clip_grad
1322
- self.noise_grad_eta = noise_grad_eta
1323
- self.noise_grad_gamma = noise_grad_gamma
1232
+ self.line_search_kwargs = line_search_kwargs
1233
+ self.noise_kwargs = noise_kwargs
1324
1234
 
1325
1235
  # set optimizer
1326
1236
  try:
1327
1237
  optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs)
1328
1238
  except Exception as _:
1329
1239
  raise_warning(
1330
- 'Failed to inject hyperparameters into optax optimizer, '
1240
+ f'Failed to inject hyperparameters into optax optimizer {optimizer}, '
1331
1241
  'rolling back to safer method: please note that modification of '
1332
- 'optimizer hyperparameters will not work, and it is '
1333
- 'recommended to update optax and related packages.', 'red')
1334
- optimizer = optimizer(**optimizer_kwargs)
1335
- if clip_grad is None:
1336
- self.optimizer = optimizer
1337
- else:
1338
- self.optimizer = optax.chain(
1339
- optax.clip(clip_grad),
1340
- optimizer
1341
- )
1242
+ 'optimizer hyperparameters will not work.', 'red')
1243
+ optimizer = optimizer(**optimizer_kwargs)
1244
+
1245
+ # apply optimizer chain of transformations
1246
+ pipeline = []
1247
+ if clip_grad is not None:
1248
+ pipeline.append(optax.clip(clip_grad))
1249
+ if noise_kwargs is not None:
1250
+ pipeline.append(optax.add_noise(**noise_kwargs))
1251
+ pipeline.append(optimizer)
1252
+ if line_search_kwargs is not None:
1253
+ pipeline.append(optax.scale_by_zoom_linesearch(**line_search_kwargs))
1254
+ self.optimizer = optax.chain(*pipeline)
1342
1255
 
1343
1256
  # set utility
1344
1257
  if isinstance(utility, str):
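
The constructor now assembles a single optax chain: optional clipping, optional Gaussian gradient noise (optax.add_noise), the base optimizer, and an optional zoom line search (optax.scale_by_zoom_linesearch). This absorbs both the old noise_grad_eta/noise_grad_gamma arguments and the removed JaxLineSearchPlanner. A hedged example; the keys inside noise_kwargs and line_search_kwargs follow the respective optax signatures and are assumptions here:

    import optax

    planner = JaxBackpropPlanner(
        rddl=env.model,                     # model object as in the earlier sketch
        plan=JaxStraightLinePlan(),
        optimizer=optax.rmsprop,
        optimizer_kwargs={'learning_rate': 0.05},
        clip_grad=1.0,                                           # -> optax.clip(1.0)
        noise_kwargs={'eta': 0.01, 'gamma': 0.55, 'seed': 42},   # -> optax.add_noise(...)
        line_search_kwargs={'max_linesearch_steps': 20}          # -> optax.scale_by_zoom_linesearch(...)
    )
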
@@ -1371,11 +1284,12 @@ class JaxBackpropPlanner:
1371
1284
  self.cpfs_without_grad = cpfs_without_grad
1372
1285
  self.compile_non_fluent_exact = compile_non_fluent_exact
1373
1286
  self.logger = logger
1287
+ self.dashboard_viz = dashboard_viz
1374
1288
 
1375
1289
  self._jax_compile_rddl()
1376
1290
  self._jax_compile_optimizer()
1377
1291
 
1378
- def _summarize_system(self) -> None:
1292
+ def summarize_system(self) -> str:
1379
1293
  try:
1380
1294
  jaxlib_version = jax._src.lib.version_str
1381
1295
  except Exception as _:
@@ -1394,36 +1308,55 @@ r"""
1394
1308
  \/_____/ \/_/\/_/ \/_/\/_/ \/_/ \/_____/ \/_/\/_/ \/_/ \/_/
1395
1309
  """
1396
1310
 
1397
- print('\n'
1398
- f'{LOGO}\n'
1399
- f'Version {__version__}\n'
1400
- f'Python {sys.version}\n'
1401
- f'jax {jax.version.__version__}, jaxlib {jaxlib_version}, '
1402
- f'optax {optax.__version__}, haiku {hk.__version__}, '
1403
- f'numpy {np.__version__}\n'
1404
- f'devices: {devices_short}\n')
1311
+ return ('\n'
1312
+ f'{LOGO}\n'
1313
+ f'Version {__version__}\n'
1314
+ f'Python {sys.version}\n'
1315
+ f'jax {jax.version.__version__}, jaxlib {jaxlib_version}, '
1316
+ f'optax {optax.__version__}, haiku {hk.__version__}, '
1317
+ f'numpy {np.__version__}\n'
1318
+ f'devices: {devices_short}\n')
1319
+
1320
+ def __str__(self) -> str:
1321
+ result = (f'objective hyper-parameters:\n'
1322
+ f' utility_fn ={self.utility.__name__}\n'
1323
+ f' utility args ={self.utility_kwargs}\n'
1324
+ f' use_symlog ={self.use_symlog_reward}\n'
1325
+ f' lookahead ={self.horizon}\n'
1326
+ f' user_action_bounds={self._action_bounds}\n'
1327
+ f' fuzzy logic type ={type(self.logic).__name__}\n'
1328
+ f' non_fluents exact ={self.compile_non_fluent_exact}\n'
1329
+ f' cpfs_no_gradient ={self.cpfs_without_grad}\n'
1330
+ f'optimizer hyper-parameters:\n'
1331
+ f' use_64_bit ={self.use64bit}\n'
1332
+ f' optimizer ={self.optimizer_name}\n'
1333
+ f' optimizer args ={self.optimizer_kwargs}\n'
1334
+ f' clip_gradient ={self.clip_grad}\n'
1335
+ f' line_search_kwargs={self.line_search_kwargs}\n'
1336
+ f' noise_kwargs ={self.noise_kwargs}\n'
1337
+ f' batch_size_train ={self.batch_size_train}\n'
1338
+ f' batch_size_test ={self.batch_size_test}')
1339
+ result += '\n' + str(self.plan)
1340
+ result += '\n' + str(self.logic)
1341
+
1342
+ # print model relaxation information
1343
+ if not self.compiled.model_params:
1344
+ return result
1345
+ result += '\n' + ('Some RDDL operations are non-differentiable '
1346
+ 'and will be approximated as follows:' + '\n')
1347
+ exprs_by_rddl_op, values_by_rddl_op = {}, {}
1348
+ for info in self.compiled.model_parameter_info().values():
1349
+ rddl_op = info['rddl_op']
1350
+ exprs_by_rddl_op.setdefault(rddl_op, []).append(info['id'])
1351
+ values_by_rddl_op.setdefault(rddl_op, []).append(info['init_value'])
1352
+ for rddl_op in sorted(exprs_by_rddl_op.keys()):
1353
+ result += (f' {rddl_op}:\n'
1354
+ f' addresses ={exprs_by_rddl_op[rddl_op]}\n'
1355
+ f' init_values={values_by_rddl_op[rddl_op]}\n')
1356
+ return result
1405
1357
 
1406
1358
  def summarize_hyperparameters(self) -> None:
1407
- print(f'objective hyper-parameters:\n'
1408
- f' utility_fn ={self.utility.__name__}\n'
1409
- f' utility args ={self.utility_kwargs}\n'
1410
- f' use_symlog ={self.use_symlog_reward}\n'
1411
- f' lookahead ={self.horizon}\n'
1412
- f' user_action_bounds={self._action_bounds}\n'
1413
- f' fuzzy logic type ={type(self.logic).__name__}\n'
1414
- f' nonfluents exact ={self.compile_non_fluent_exact}\n'
1415
- f' cpfs_no_gradient ={self.cpfs_without_grad}\n'
1416
- f'optimizer hyper-parameters:\n'
1417
- f' use_64_bit ={self.use64bit}\n'
1418
- f' optimizer ={self._optimizer_name.__name__}\n'
1419
- f' optimizer args ={self._optimizer_kwargs}\n'
1420
- f' clip_gradient ={self.clip_grad}\n'
1421
- f' noise_grad_eta ={self.noise_grad_eta}\n'
1422
- f' noise_grad_gamma ={self.noise_grad_gamma}\n'
1423
- f' batch_size_train ={self.batch_size_train}\n'
1424
- f' batch_size_test ={self.batch_size_test}')
1425
- self.plan.summarize_hyperparameters()
1426
- self.logic.summarize_hyperparameters()
1359
+ print(self.__str__())
1427
1360
 
1428
1361
  # ===========================================================================
1429
1362
  # COMPILATION SUBROUTINES
@@ -1439,14 +1372,16 @@ r"""
1439
1372
  logger=self.logger,
1440
1373
  use64bit=self.use64bit,
1441
1374
  cpfs_without_grad=self.cpfs_without_grad,
1442
- compile_non_fluent_exact=self.compile_non_fluent_exact)
1375
+ compile_non_fluent_exact=self.compile_non_fluent_exact
1376
+ )
1443
1377
  self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')
1444
1378
 
1445
1379
  # Jax compilation of the exact RDDL for testing
1446
1380
  self.test_compiled = JaxRDDLCompiler(
1447
1381
  rddl=rddl,
1448
1382
  logger=self.logger,
1449
- use64bit=self.use64bit)
1383
+ use64bit=self.use64bit
1384
+ )
1450
1385
  self.test_compiled.compile(log_jax_expr=True, heading='EXACT MODEL')
1451
1386
 
1452
1387
  def _jax_compile_optimizer(self):
@@ -1462,13 +1397,15 @@ r"""
1462
1397
  train_rollouts = self.compiled.compile_rollouts(
1463
1398
  policy=self.plan.train_policy,
1464
1399
  n_steps=self.horizon,
1465
- n_batch=self.batch_size_train)
1400
+ n_batch=self.batch_size_train
1401
+ )
1466
1402
  self.train_rollouts = train_rollouts
1467
1403
 
1468
1404
  test_rollouts = self.test_compiled.compile_rollouts(
1469
1405
  policy=self.plan.test_policy,
1470
1406
  n_steps=self.horizon,
1471
- n_batch=self.batch_size_test)
1407
+ n_batch=self.batch_size_test
1408
+ )
1472
1409
  self.test_rollouts = jax.jit(test_rollouts)
1473
1410
 
1474
1411
  # initialization
@@ -1503,14 +1440,16 @@ r"""
1503
1440
  _jax_wrapped_returns = self._jax_return(use_symlog)
1504
1441
 
1505
1442
  # the loss is the average cumulative reward across all roll-outs
1506
- def _jax_wrapped_plan_loss(key, policy_params, hyperparams,
1443
+ def _jax_wrapped_plan_loss(key, policy_params, policy_hyperparams,
1507
1444
  subs, model_params):
1508
- log = rollouts(key, policy_params, hyperparams, subs, model_params)
1445
+ log, model_params = rollouts(
1446
+ key, policy_params, policy_hyperparams, subs, model_params)
1509
1447
  rewards = log['reward']
1510
1448
  returns = _jax_wrapped_returns(rewards)
1511
1449
  utility = utility_fn(returns, **utility_kwargs)
1512
1450
  loss = -utility
1513
- return loss, log
1451
+ aux = (log, model_params)
1452
+ return loss, aux
1514
1453
 
1515
1454
  return _jax_wrapped_plan_loss
1516
1455
 
@@ -1518,8 +1457,9 @@ r"""
1518
1457
  init = self.plan.initializer
1519
1458
  optimizer = self.optimizer
1520
1459
 
1521
- def _jax_wrapped_init_policy(key, hyperparams, subs):
1522
- policy_params = init(key, hyperparams, subs)
1460
+ # initialize both the policy and its optimizer
1461
+ def _jax_wrapped_init_policy(key, policy_hyperparams, subs):
1462
+ policy_params = init(key, policy_hyperparams, subs)
1523
1463
  opt_state = optimizer.init(policy_params)
1524
1464
  return policy_params, opt_state, {}
1525
1465
 
@@ -1528,35 +1468,34 @@ r"""
1528
1468
  def _jax_update(self, loss):
1529
1469
  optimizer = self.optimizer
1530
1470
  projection = self.plan.projection
1531
-
1532
- # add Gaussian gradient noise per Neelakantan et al., 2016.
1533
- def _jax_wrapped_gaussian_param_noise(key, grads, sigma):
1534
- treedef = jax.tree_util.tree_structure(grads)
1535
- keys_flat = random.split(key, num=treedef.num_leaves)
1536
- keys_tree = jax.tree_util.tree_unflatten(treedef, keys_flat)
1537
- new_grads = jax.tree_map(
1538
- lambda g, k: g + sigma * random.normal(
1539
- key=k, shape=g.shape, dtype=g.dtype),
1540
- grads,
1541
- keys_tree
1542
- )
1543
- return new_grads
1471
+ use_ls = self.line_search_kwargs is not None
1544
1472
 
1545
1473
  # calculate the plan gradient w.r.t. return loss and update optimizer
1546
1474
  # also perform a projection step to satisfy constraints on actions
1547
- def _jax_wrapped_plan_update(key, policy_params, hyperparams,
1475
+ def _jax_wrapped_loss_swapped(policy_params, key, policy_hyperparams,
1476
+ subs, model_params):
1477
+ return loss(key, policy_params, policy_hyperparams, subs, model_params)[0]
1478
+
1479
+ def _jax_wrapped_plan_update(key, policy_params, policy_hyperparams,
1548
1480
  subs, model_params, opt_state, opt_aux):
1549
1481
  grad_fn = jax.value_and_grad(loss, argnums=1, has_aux=True)
1550
- (loss_val, log), grad = grad_fn(
1551
- key, policy_params, hyperparams, subs, model_params)
1552
- sigma = opt_aux.get('noise_sigma', 0.0)
1553
- grad = _jax_wrapped_gaussian_param_noise(key, grad, sigma)
1554
- updates, opt_state = optimizer.update(grad, opt_state)
1482
+ (loss_val, (log, model_params)), grad = grad_fn(
1483
+ key, policy_params, policy_hyperparams, subs, model_params)
1484
+ if use_ls:
1485
+ updates, opt_state = optimizer.update(
1486
+ grad, opt_state, params=policy_params,
1487
+ value=loss_val, grad=grad, value_fn=_jax_wrapped_loss_swapped,
1488
+ key=key, policy_hyperparams=policy_hyperparams, subs=subs,
1489
+ model_params=model_params)
1490
+ else:
1491
+ updates, opt_state = optimizer.update(
1492
+ grad, opt_state, params=policy_params)
1555
1493
  policy_params = optax.apply_updates(policy_params, updates)
1556
- policy_params, converged = projection(policy_params, hyperparams)
1494
+ policy_params, converged = projection(policy_params, policy_hyperparams)
1557
1495
  log['grad'] = grad
1558
1496
  log['updates'] = updates
1559
- return policy_params, converged, opt_state, opt_aux, loss_val, log
1497
+ return policy_params, converged, opt_state, opt_aux, \
1498
+ loss_val, log, model_params
1560
1499
 
1561
1500
  return jax.jit(_jax_wrapped_plan_update)
1562
1501
 
@@ -1583,7 +1522,6 @@ r"""
1583
1522
  for (state, next_state) in rddl.next_state.items():
1584
1523
  init_train[next_state] = init_train[state]
1585
1524
  init_test[next_state] = init_test[state]
1586
-
1587
1525
  return init_train, init_test
1588
1526
 
1589
1527
  def as_optimization_problem(
@@ -1637,38 +1575,40 @@ r"""
1637
1575
  loss_fn = self._jax_loss(self.train_rollouts)
1638
1576
 
1639
1577
  @jax.jit
1640
- def _loss_with_key(key, params_1d):
1578
+ def _loss_with_key(key, params_1d, model_params):
1641
1579
  policy_params = unravel_fn(params_1d)
1642
- loss_val, _ = loss_fn(key, policy_params, policy_hyperparams,
1643
- train_subs, model_params)
1644
- return loss_val
1580
+ loss_val, (_, model_params) = loss_fn(
1581
+ key, policy_params, policy_hyperparams, train_subs, model_params)
1582
+ return loss_val, model_params
1645
1583
 
1646
1584
  @jax.jit
1647
- def _grad_with_key(key, params_1d):
1585
+ def _grad_with_key(key, params_1d, model_params):
1648
1586
  policy_params = unravel_fn(params_1d)
1649
1587
  grad_fn = jax.grad(loss_fn, argnums=1, has_aux=True)
1650
- grad_val, _ = grad_fn(key, policy_params, policy_hyperparams,
1651
- train_subs, model_params)
1652
- grad_1d = jax.flatten_util.ravel_pytree(grad_val)[0]
1653
- return grad_1d
1588
+ grad_val, (_, model_params) = grad_fn(
1589
+ key, policy_params, policy_hyperparams, train_subs, model_params)
1590
+ grad_val = jax.flatten_util.ravel_pytree(grad_val)[0]
1591
+ return grad_val, model_params
1654
1592
 
1655
1593
  def _loss_function(params_1d):
1656
1594
  nonlocal key
1595
+ nonlocal model_params
1657
1596
  if loss_function_updates_key:
1658
1597
  key, subkey = random.split(key)
1659
1598
  else:
1660
1599
  subkey = key
1661
- loss_val = _loss_with_key(subkey, params_1d)
1600
+ loss_val, model_params = _loss_with_key(subkey, params_1d, model_params)
1662
1601
  loss_val = float(loss_val)
1663
1602
  return loss_val
1664
1603
 
1665
1604
  def _grad_function(params_1d):
1666
1605
  nonlocal key
1606
+ nonlocal model_params
1667
1607
  if grad_function_updates_key:
1668
1608
  key, subkey = random.split(key)
1669
1609
  else:
1670
1610
  subkey = key
1671
- grad = _grad_with_key(subkey, params_1d)
1611
+ grad, model_params = _grad_with_key(subkey, params_1d, model_params)
1672
1612
  grad = np.asarray(grad)
1673
1613
  return grad
1674
1614
 
@@ -1683,9 +1623,9 @@ r"""
1683
1623
 
1684
1624
  :param key: JAX PRNG key (derived from clock if not provided)
1685
1625
  :param epochs: the maximum number of steps of gradient descent
1686
- :param train_seconds: total time allocated for gradient descent
1687
- :param plot_step: frequency to plot the plan and save result to disk
1688
- :param plot_kwargs: additional arguments to pass to the plotter
1626
+ :param train_seconds: total time allocated for gradient descent
1627
+ :param dashboard: dashboard to display training results
1628
+ :param dashboard_id: experiment id for the dashboard
1689
1629
  :param model_params: optional model-parameters to override default
1690
1630
  :param policy_hyperparams: hyper-parameters for the policy/plan, such as
1691
1631
  weights for sigmoid wrapping boolean actions
@@ -1721,8 +1661,8 @@ r"""
1721
1661
  def optimize_generator(self, key: Optional[random.PRNGKey]=None,
1722
1662
  epochs: int=999999,
1723
1663
  train_seconds: float=120.,
1724
- plot_step: Optional[int]=None,
1725
- plot_kwargs: Optional[Kwargs]=None,
1664
+ dashboard: Optional[JaxPlannerDashboard]=None,
1665
+ dashboard_id: Optional[str]=None,
1726
1666
  model_params: Optional[Dict[str, Any]]=None,
1727
1667
  policy_hyperparams: Optional[Dict[str, Any]]=None,
1728
1668
  subs: Optional[Dict[str, Any]]=None,
@@ -1738,9 +1678,9 @@ r"""
1738
1678
 
1739
1679
  :param key: JAX PRNG key (derived from clock if not provided)
1740
1680
  :param epochs: the maximum number of steps of gradient descent
1741
- :param train_seconds: total time allocated for gradient descent
1742
- :param plot_step: frequency to plot the plan and save result to disk
1743
- :param plot_kwargs: additional arguments to pass to the plotter
1681
+ :param train_seconds: total time allocated for gradient descent
1682
+ :param dashboard: dashboard to display training results
1683
+ :param dashboard_id: experiment id for the dashboard
1744
1684
  :param model_params: optional model-parameters to override default
1745
1685
  :param policy_hyperparams: hyper-parameters for the policy/plan, such as
1746
1686
  weights for sigmoid wrapping boolean actions
@@ -1760,9 +1700,14 @@ r"""
1760
1700
  start_time = time.time()
1761
1701
  elapsed_outside_loop = 0
1762
1702
 
1703
+ # ======================================================================
1704
+ # INITIALIZATION OF HYPER-PARAMETERS
1705
+ # ======================================================================
1706
+
1763
1707
  # if PRNG key is not provided
1764
1708
  if key is None:
1765
1709
  key = random.PRNGKey(round(time.time() * 1000))
1710
+ dash_key = key[1].item()
1766
1711
 
1767
1712
  # if policy_hyperparams is not provided
1768
1713
  if policy_hyperparams is None:
@@ -1789,7 +1734,7 @@ r"""
1789
1734
 
1790
1735
  # print summary of parameters:
1791
1736
  if print_summary:
1792
- self._summarize_system()
1737
+ print(self.summarize_system())
1793
1738
  self.summarize_hyperparameters()
1794
1739
  print(f'optimize() call hyper-parameters:\n'
1795
1740
  f' PRNG key ={key}\n'
@@ -1800,16 +1745,16 @@ r"""
1800
1745
  f' override_subs_dict ={subs is not None}\n'
1801
1746
  f' provide_param_guess={guess is not None}\n'
1802
1747
  f' test_rolling_window={test_rolling_window}\n'
1803
- f' plot_frequency ={plot_step}\n'
1804
- f' plot_kwargs ={plot_kwargs}\n'
1748
+ f' dashboard ={dashboard is not None}\n'
1749
+ f' dashboard_id ={dashboard_id}\n'
1805
1750
  f' print_summary ={print_summary}\n'
1806
1751
  f' print_progress ={print_progress}\n'
1807
1752
  f' stopping_rule ={stopping_rule}\n')
1808
- if self.compiled.relaxations:
1809
- print('Some RDDL operations are non-differentiable, '
1810
- 'they will be approximated as follows:')
1811
- print(self.compiled.summarize_model_relaxations())
1812
-
1753
+
1754
+ # ======================================================================
1755
+ # INITIALIZATION OF STATE AND POLICY
1756
+ # ======================================================================
1757
+
1813
1758
  # compute a batched version of the initial values
1814
1759
  if subs is None:
1815
1760
  subs = self.test_compiled.init_values
@@ -1841,6 +1786,10 @@ r"""
1841
1786
  policy_params = guess
1842
1787
  opt_state = self.optimizer.init(policy_params)
1843
1788
  opt_aux = {}
1789
+
1790
+ # ======================================================================
1791
+ # INITIALIZATION OF RUNNING STATISTICS
1792
+ # ======================================================================
1844
1793
 
1845
1794
  # initialize running statistics
1846
1795
  best_params, best_loss, best_grad = policy_params, jnp.inf, jnp.inf
@@ -1854,35 +1803,39 @@ r"""
1854
1803
  if stopping_rule is not None:
1855
1804
  stopping_rule.reset()
1856
1805
 
1857
- # initialize plot area
1858
- if plot_step is None or plot_step <= 0 or plt is None:
1859
- plot = None
1860
- else:
1861
- if plot_kwargs is None:
1862
- plot_kwargs = {}
1863
- plot = JaxPlannerPlot(self.rddl, self.horizon, **plot_kwargs)
1864
- xticks, loss_values = [], []
1806
+ # initialize dash board
1807
+ if dashboard is not None:
1808
+ dashboard_id = dashboard.register_experiment(
1809
+ dashboard_id, dashboard.get_planner_info(self),
1810
+ key=dash_key, viz=self.dashboard_viz)
1811
+
1812
+ # ======================================================================
1813
+ # MAIN TRAINING LOOP BEGINS
1814
+ # ======================================================================
1865
1815
 
1866
- # training loop
1867
1816
  iters = range(epochs)
1868
1817
  if print_progress:
1869
1818
  iters = tqdm(iters, total=100, position=tqdm_position)
1870
1819
  position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
1871
1820
 
1872
1821
  for it in iters:
1873
- status = JaxPlannerStatus.NORMAL
1874
1822
 
1875
- # gradient noise schedule
1876
- noise_var = self.noise_grad_eta / (1. + it) ** self.noise_grad_gamma
1877
- noise_sigma = np.sqrt(noise_var)
1878
- opt_aux['noise_sigma'] = noise_sigma
1823
+ # ==================================================================
1824
+ # NEXT GRADIENT DESCENT STEP
1825
+ # ==================================================================
1826
+
1827
+ status = JaxPlannerStatus.NORMAL
1879
1828
 
1880
1829
  # update the parameters of the plan
1881
1830
  key, subkey = random.split(key)
1882
- policy_params, converged, opt_state, opt_aux, \
1883
- train_loss, train_log = \
1831
+ (policy_params, converged, opt_state, opt_aux,
1832
+ train_loss, train_log, model_params) = \
1884
1833
  self.update(subkey, policy_params, policy_hyperparams,
1885
1834
  train_subs, model_params, opt_state, opt_aux)
1835
+
1836
+ # ==================================================================
1837
+ # STATUS CHECKS AND LOGGING
1838
+ # ==================================================================
1886
1839
 
1887
1840
  # no progress
1888
1841
  grad_norm_zero, _ = jax.tree_util.tree_flatten(
@@ -1901,12 +1854,11 @@ r"""
1901
1854
  # numerical error
1902
1855
  if not np.isfinite(train_loss):
1903
1856
  raise_warning(
1904
- f'Aborting JAX planner due to invalid train loss {train_loss}.',
1905
- 'red')
1857
+ f'JAX planner aborted due to invalid loss {train_loss}.', 'red')
1906
1858
  status = JaxPlannerStatus.INVALID_GRADIENT
1907
1859
 
1908
1860
  # evaluate test losses and record best plan so far
1909
- test_loss, log = self.test_loss(
1861
+ test_loss, (log, model_params_test) = self.test_loss(
1910
1862
  subkey, policy_params, policy_hyperparams,
1911
1863
  test_subs, model_params_test)
1912
1864
  test_loss = rolling_test_loss.update(test_loss)
@@ -1915,32 +1867,15 @@ r"""
1915
1867
  policy_params, test_loss, train_log['grad']
1916
1868
  last_iter_improve = it
1917
1869
 
1918
- # save the plan figure
1919
- if plot is not None and it % plot_step == 0:
1920
- xticks.append(it // plot_step)
1921
- loss_values.append(test_loss.item())
1922
- action_values = {name: values
1923
- for (name, values) in log['fluents'].items()
1924
- if name in self.rddl.action_fluents}
1925
- returns = -np.sum(np.asarray(log['reward']), axis=1)
1926
- plot.redraw(xticks, loss_values, action_values, returns)
1927
-
1928
- # if the progress bar is used
1929
- elapsed = time.time() - start_time - elapsed_outside_loop
1930
- if print_progress:
1931
- iters.n = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
1932
- iters.set_description(
1933
- f'{position_str} {it:6} it / {-train_loss:14.6f} train / '
1934
- f'{-test_loss:14.6f} test / {-best_loss:14.6f} best / '
1935
- f'{status.value} status')
1936
-
1937
1870
  # reached computation budget
1871
+ elapsed = time.time() - start_time - elapsed_outside_loop
1938
1872
  if elapsed >= train_seconds:
1939
1873
  status = JaxPlannerStatus.TIME_BUDGET_REACHED
1940
1874
  if it >= epochs - 1:
1941
1875
  status = JaxPlannerStatus.ITER_BUDGET_REACHED
1942
1876
 
1943
- # return a callback
1877
+ # build a callback
1878
+ progress_percent = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
1944
1879
  callback = {
1945
1880
  'status': status,
1946
1881
  'iteration': it,
@@ -1952,29 +1887,48 @@ r"""
1952
1887
  'last_iteration_improved': last_iter_improve,
1953
1888
  'grad': train_log['grad'],
1954
1889
  'best_grad': best_grad,
1955
- 'noise_sigma': noise_sigma,
1956
1890
  'updates': train_log['updates'],
1957
1891
  'elapsed_time': elapsed,
1958
1892
  'key': key,
1893
+ 'model_params': model_params,
1894
+ 'progress': progress_percent,
1895
+ 'train_log': train_log,
1959
1896
  **log
1960
1897
  }
1898
+
1899
+ # stopping condition reached
1900
+ if stopping_rule is not None and stopping_rule.monitor(callback):
1901
+ callback['status'] = status = JaxPlannerStatus.STOPPING_RULE_REACHED
1902
+
1903
+ # if the progress bar is used
1904
+ if print_progress:
1905
+ iters.n = progress_percent
1906
+ iters.set_description(
1907
+ f'{position_str} {it:6} it / {-train_loss:14.6f} train / '
1908
+ f'{-test_loss:14.6f} test / {-best_loss:14.6f} best / '
1909
+ f'{status.value} status'
1910
+ )
1911
+
1912
+ # dash-board
1913
+ if dashboard is not None:
1914
+ dashboard.update_experiment(dashboard_id, callback)
1915
+
1916
+ # yield the callback
1961
1917
  start_time_outside = time.time()
1962
1918
  yield callback
1963
1919
  elapsed_outside_loop += (time.time() - start_time_outside)
1964
1920
 
1965
1921
  # abortion check
1966
- if status.is_failure():
1967
- break
1968
-
1969
- # stopping condition reached
1970
- if stopping_rule is not None and stopping_rule.monitor(callback):
1971
- break
1972
-
1922
+ if status.is_terminal():
1923
+ break
1924
+
1925
+ # ======================================================================
1926
+ # POST-PROCESSING AND CLEANUP
1927
+ # ======================================================================
1928
+
1973
1929
  # release resources
1974
1930
  if print_progress:
1975
1931
  iters.close()
1976
- if plot is not None:
1977
- plot.close()
1978
1932
 
1979
1933
  # validate the test return
1980
1934
  if log:
@@ -2098,101 +2052,6 @@ r"""
2098
2052
  return actions
2099
2053
 
2100
2054
 
2101
- class JaxLineSearchPlanner(JaxBackpropPlanner):
2102
- '''A class for optimizing an action sequence in the given RDDL MDP using
2103
- linear search gradient descent, with the Armijo condition.'''
2104
-
2105
- def __init__(self, *args,
2106
- decay: float=0.8,
2107
- c: float=0.1,
2108
- step_max: float=1.0,
2109
- step_min: float=1e-6,
2110
- **kwargs) -> None:
2111
- '''Creates a new gradient-based algorithm for optimizing action sequences
2112
- (plan) in the given RDDL using line search. All arguments are the
2113
- same as in the parent class, except:
2114
-
2115
- :param decay: reduction factor of learning rate per line search iteration
2116
- :param c: positive coefficient in Armijo condition, should be in (0, 1)
2117
- :param step_max: initial learning rate for line search
2118
- :param step_min: minimum possible learning rate (line search halts)
2119
- '''
2120
- self.decay = decay
2121
- self.c = c
2122
- self.step_max = step_max
2123
- self.step_min = step_min
2124
- if 'clip_grad' in kwargs:
2125
- raise_warning('clip_grad parameter conflicts with '
2126
- 'line search planner and will be ignored.', 'red')
2127
- del kwargs['clip_grad']
2128
- super(JaxLineSearchPlanner, self).__init__(*args, **kwargs)
2129
-
2130
- def summarize_hyperparameters(self) -> None:
2131
- super(JaxLineSearchPlanner, self).summarize_hyperparameters()
2132
- print(f'linesearch hyper-parameters:\n'
2133
- f' decay ={self.decay}\n'
2134
- f' c ={self.c}\n'
2135
- f' lr_range=({self.step_min}, {self.step_max})')
2136
-
2137
- def _jax_update(self, loss):
2138
- optimizer = self.optimizer
2139
- projection = self.plan.projection
2140
- decay, c, lrmax, lrmin = self.decay, self.c, self.step_max, self.step_min
2141
-
2142
- # initialize the line search routine
2143
- @jax.jit
2144
- def _jax_wrapped_line_search_init(key, policy_params, hyperparams,
2145
- subs, model_params):
2146
- (f, log), grad = jax.value_and_grad(loss, argnums=1, has_aux=True)(
2147
- key, policy_params, hyperparams, subs, model_params)
2148
- gnorm2 = jax.tree_map(lambda x: jnp.sum(jnp.square(x)), grad)
2149
- gnorm2 = jax.tree_util.tree_reduce(jnp.add, gnorm2)
2150
- log['grad'] = grad
2151
- return f, grad, gnorm2, log
2152
-
2153
- # compute the next trial solution
2154
- @jax.jit
2155
- def _jax_wrapped_line_search_trial(
2156
- step, grad, key, params, hparams, subs, mparams, state):
2157
- state.hyperparams['learning_rate'] = step
2158
- updates, new_state = optimizer.update(grad, state)
2159
- new_params = optax.apply_updates(params, updates)
2160
- new_params, _ = projection(new_params, hparams)
2161
- f_step, _ = loss(key, new_params, hparams, subs, mparams)
2162
- return f_step, new_params, new_state
2163
-
2164
- # main iteration of line search
2165
- def _jax_wrapped_plan_update(key, policy_params, hyperparams,
2166
- subs, model_params, opt_state, opt_aux):
2167
-
2168
- # initialize the line search
2169
- f, grad, gnorm2, log = _jax_wrapped_line_search_init(
2170
- key, policy_params, hyperparams, subs, model_params)
2171
-
2172
- # continue to reduce the learning rate until the Armijo condition holds
2173
- trials = 0
2174
- step = lrmax / decay
2175
- f_step = np.inf
2176
- best_f, best_step, best_params, best_state = np.inf, None, None, None
2177
- while (f_step > f - c * step * gnorm2 and step * decay >= lrmin) \
2178
- or not trials:
2179
- trials += 1
2180
- step *= decay
2181
- f_step, new_params, new_state = _jax_wrapped_line_search_trial(
2182
- step, grad, key, policy_params, hyperparams, subs,
2183
- model_params, opt_state)
2184
- if f_step < best_f:
2185
- best_f, best_step, best_params, best_state = \
2186
- f_step, step, new_params, new_state
2187
-
2188
- log['updates'] = None
2189
- log['line_search_iters'] = trials
2190
- log['learning_rate'] = best_step
2191
- opt_aux['best_step'] = best_step
2192
- return best_params, True, best_state, opt_aux, best_f, log
2193
-
2194
- return _jax_wrapped_plan_update
2195
-
2196
2055
  # ***********************************************************************
2197
2056
  # ALL VERSIONS OF RISK FUNCTIONS
2198
2057
  #
@@ -2224,6 +2083,7 @@ def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
2224
2083
  returns <= jnp.percentile(returns, q=100 * alpha))
2225
2084
  return jnp.sum(returns * alpha_mask) / jnp.sum(alpha_mask)
2226
2085
 
2086
+
2227
2087
  # ***********************************************************************
2228
2088
  # ALL VERSIONS OF CONTROLLERS
2229
2089
  #
@@ -2268,8 +2128,10 @@ class JaxOfflineController(BaseAgent):
2268
2128
  self.params_given = params is not None
2269
2129
 
2270
2130
  self.step = 0
2131
+ self.callback = None
2271
2132
  if not self.train_on_reset and not self.params_given:
2272
2133
  callback = self.planner.optimize(key=self.key, **self.train_kwargs)
2134
+ self.callback = callback
2273
2135
  params = callback['best_params']
2274
2136
  self.params = params
2275
2137
 
@@ -2284,6 +2146,7 @@ class JaxOfflineController(BaseAgent):
2284
2146
  self.step = 0
2285
2147
  if self.train_on_reset and not self.params_given:
2286
2148
  callback = self.planner.optimize(key=self.key, **self.train_kwargs)
2149
+ self.callback = callback
2287
2150
  self.params = callback['best_params']
2288
2151
 
2289
2152
 
@@ -2326,7 +2189,9 @@ class JaxOnlineController(BaseAgent):
2326
2189
  key=self.key,
2327
2190
  guess=self.guess,
2328
2191
  subs=state,
2329
- **self.train_kwargs)
2192
+ **self.train_kwargs
2193
+ )
2194
+ self.callback = callback
2330
2195
  params = callback['best_params']
2331
2196
  self.key, subkey = random.split(self.key)
2332
2197
  actions = planner.get_action(
@@ -2337,4 +2202,5 @@ class JaxOnlineController(BaseAgent):
2337
2202
 
2338
2203
  def reset(self) -> None:
2339
2204
  self.guess = None
2205
+ self.callback = None
2340
2206