pyRDDLGym-jax 1.3-py3-none-any.whl → 2.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,12 +1,43 @@
1
+ # ***********************************************************************
2
+ # JAXPLAN
3
+ #
4
+ # Author: Michael Gimelfarb
5
+ #
6
+ # RELEVANT SOURCES:
7
+ #
8
+ # [1] Gimelfarb, Michael, Ayal Taitler, and Scott Sanner. "JaxPlan and GurobiPlan:
9
+ # Optimization Baselines for Replanning in Discrete and Mixed Discrete-Continuous
10
+ # Probabilistic Domains." Proceedings of the International Conference on Automated
11
+ # Planning and Scheduling. Vol. 34. 2024.
12
+ #
13
+ # [2] Patton, Noah, Jihwan Jeong, Mike Gimelfarb, and Scott Sanner. "A Distributional
14
+ # Framework for Risk-Sensitive End-to-End Planning in Continuous MDPs." In Proceedings of
15
+ # the AAAI Conference on Artificial Intelligence, vol. 36, no. 9, pp. 9894-9901. 2022.
16
+ #
17
+ # [3] Bueno, Thiago P., Leliane N. de Barros, Denis D. Mauá, and Scott Sanner. "Deep
18
+ # reactive policies for planning in stochastic nonlinear domains." In Proceedings of the
19
+ # AAAI Conference on Artificial Intelligence, vol. 33, no. 01, pp. 7530-7537. 2019.
20
+ #
21
+ # [4] Wu, Ga, Buser Say, and Scott Sanner. "Scalable planning with tensorflow for hybrid
22
+ # nonlinear domains." Advances in Neural Information Processing Systems 30 (2017).
23
+ #
24
+ # [5] Sehnke, Frank, and Tingting Zhao. "Baseline-free sampling in parameter exploring
25
+ # policy gradients: Super symmetric pgpe." Artificial Neural Networks: Methods and
26
+ # Applications in Bio-/Neuroinformatics. Springer International Publishing, 2015.
27
+ #
28
+ # ***********************************************************************
29
+
30
+
1
31
  from ast import literal_eval
2
32
  from collections import deque
3
33
  import configparser
4
34
  from enum import Enum
35
+ from functools import partial
5
36
  import os
6
37
  import sys
7
38
  import time
8
39
  import traceback
9
- from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Tuple, Union
40
+ from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Type, Tuple, Union
10
41
 
11
42
  import haiku as hk
12
43
  import jax
@@ -38,8 +69,7 @@ try:
38
69
  from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard
39
70
  except Exception:
40
71
  raise_warning('Failed to load the dashboard visualization tool: '
41
- 'please make sure you have installed the required packages.',
42
- 'red')
72
+ 'please make sure you have installed the required packages.', 'red')
43
73
  traceback.print_exc()
44
74
  JaxPlannerDashboard = None
45
75
 
@@ -102,7 +132,7 @@ def _load_config(config, args):
102
132
  comp_kwargs = model_args.get('complement_kwargs', {})
103
133
  compare_name = model_args.get('comparison', 'SigmoidComparison')
104
134
  compare_kwargs = model_args.get('comparison_kwargs', {})
105
- sampling_name = model_args.get('sampling', 'GumbelSoftmax')
135
+ sampling_name = model_args.get('sampling', 'SoftRandomSampling')
106
136
  sampling_kwargs = model_args.get('sampling_kwargs', {})
107
137
  rounding_name = model_args.get('rounding', 'SoftRounding')
108
138
  rounding_kwargs = model_args.get('rounding_kwargs', {})
@@ -125,8 +155,7 @@ def _load_config(config, args):
125
155
  initializer = _getattr_any(
126
156
  packages=[initializers, hk.initializers], item=plan_initializer)
127
157
  if initializer is None:
128
- raise_warning(
129
- f'Ignoring invalid initializer <{plan_initializer}>.', 'red')
158
+ raise_warning(f'Ignoring invalid initializer <{plan_initializer}>.', 'red')
130
159
  del plan_kwargs['initializer']
131
160
  else:
132
161
  init_kwargs = plan_kwargs.pop('initializer_kwargs', {})
@@ -143,8 +172,7 @@ def _load_config(config, args):
143
172
  activation = _getattr_any(
144
173
  packages=[jax.nn, jax.numpy], item=plan_activation)
145
174
  if activation is None:
146
- raise_warning(
147
- f'Ignoring invalid activation <{plan_activation}>.', 'red')
175
+ raise_warning(f'Ignoring invalid activation <{plan_activation}>.', 'red')
148
176
  del plan_kwargs['activation']
149
177
  else:
150
178
  plan_kwargs['activation'] = activation
@@ -158,12 +186,24 @@ def _load_config(config, args):
158
186
  if planner_optimizer is not None:
159
187
  optimizer = _getattr_any(packages=[optax], item=planner_optimizer)
160
188
  if optimizer is None:
161
- raise_warning(
162
- f'Ignoring invalid optimizer <{planner_optimizer}>.', 'red')
189
+ raise_warning(f'Ignoring invalid optimizer <{planner_optimizer}>.', 'red')
163
190
  del planner_args['optimizer']
164
191
  else:
165
192
  planner_args['optimizer'] = optimizer
166
-
193
+
194
+ # pgpe optimizer
195
+ pgpe_method = planner_args.get('pgpe', 'GaussianPGPE')
196
+ pgpe_kwargs = planner_args.pop('pgpe_kwargs', {})
197
+ if pgpe_method is not None:
198
+ if 'optimizer' in pgpe_kwargs:
199
+ pgpe_optimizer = _getattr_any(packages=[optax], item=pgpe_kwargs['optimizer'])
200
+ if pgpe_optimizer is None:
201
+ raise_warning(f'Ignoring invalid optimizer <{pgpe_optimizer}>.', 'red')
202
+ del pgpe_kwargs['optimizer']
203
+ else:
204
+ pgpe_kwargs['optimizer'] = pgpe_optimizer
205
+ planner_args['pgpe'] = getattr(sys.modules[__name__], pgpe_method)(**pgpe_kwargs)
206
+
167
207
  # optimize call RNG key
168
208
  planner_key = train_args.get('key', None)
169
209
  if planner_key is not None:
@@ -241,48 +281,14 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
241
281
  pvars_cast = set()
242
282
  for (var, values) in self.init_values.items():
243
283
  self.init_values[var] = np.asarray(values, dtype=self.REAL)
244
- if not np.issubdtype(np.atleast_1d(values).dtype, np.floating):
284
+ if not np.issubdtype(np.result_type(values), np.floating):
245
285
  pvars_cast.add(var)
246
286
  if pvars_cast:
247
287
  raise_warning(f'JAX gradient compiler requires that initial values '
248
288
  f'of p-variables {pvars_cast} be cast to float.')
249
289
 
250
290
  # overwrite basic operations with fuzzy ones
251
- self.RELATIONAL_OPS = {
252
- '>=': logic.greater_equal,
253
- '<=': logic.less_equal,
254
- '<': logic.less,
255
- '>': logic.greater,
256
- '==': logic.equal,
257
- '~=': logic.not_equal
258
- }
259
- self.LOGICAL_NOT = logic.logical_not
260
- self.LOGICAL_OPS = {
261
- '^': logic.logical_and,
262
- '&': logic.logical_and,
263
- '|': logic.logical_or,
264
- '~': logic.xor,
265
- '=>': logic.implies,
266
- '<=>': logic.equiv
267
- }
268
- self.AGGREGATION_OPS['forall'] = logic.forall
269
- self.AGGREGATION_OPS['exists'] = logic.exists
270
- self.AGGREGATION_OPS['argmin'] = logic.argmin
271
- self.AGGREGATION_OPS['argmax'] = logic.argmax
272
- self.KNOWN_UNARY['sgn'] = logic.sgn
273
- self.KNOWN_UNARY['floor'] = logic.floor
274
- self.KNOWN_UNARY['ceil'] = logic.ceil
275
- self.KNOWN_UNARY['round'] = logic.round
276
- self.KNOWN_UNARY['sqrt'] = logic.sqrt
277
- self.KNOWN_BINARY['div'] = logic.div
278
- self.KNOWN_BINARY['mod'] = logic.mod
279
- self.KNOWN_BINARY['fmod'] = logic.mod
280
- self.IF_HELPER = logic.control_if
281
- self.SWITCH_HELPER = logic.control_switch
282
- self.BERNOULLI_HELPER = logic.bernoulli
283
- self.DISCRETE_HELPER = logic.discrete
284
- self.POISSON_HELPER = logic.poisson
285
- self.GEOMETRIC_HELPER = logic.geometric
291
+ self.OPS = logic.get_operator_dicts()
286
292
 
287
293
  def _jax_stop_grad(self, jax_expr):
288
294
  def _jax_wrapped_stop_grad(x, params, key):
@@ -469,16 +475,16 @@ class JaxStraightLinePlan(JaxPlan):
469
475
  bounds = '\n '.join(
470
476
  map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
471
477
  return (f'policy hyper-parameters:\n'
472
- f' initializer ={self._initializer_base}\n'
473
- f'constraint-sat strategy (simple):\n'
474
- f' parsed_action_bounds =\n {bounds}\n'
475
- f' wrap_sigmoid ={self._wrap_sigmoid}\n'
476
- f' wrap_sigmoid_min_prob={self._min_action_prob}\n'
477
- f' wrap_non_bool ={self._wrap_non_bool}\n'
478
- f'constraint-sat strategy (complex):\n'
479
- f' wrap_softmax ={self._wrap_softmax}\n'
480
- f' use_new_projection ={self._use_new_projection}\n'
481
- f' max_projection_iters ={self._max_constraint_iter}')
478
+ f' initializer={self._initializer_base}\n'
479
+ f' constraint-sat strategy (simple):\n'
480
+ f' parsed_action_bounds =\n {bounds}\n'
481
+ f' wrap_sigmoid ={self._wrap_sigmoid}\n'
482
+ f' wrap_sigmoid_min_prob={self._min_action_prob}\n'
483
+ f' wrap_non_bool ={self._wrap_non_bool}\n'
484
+ f' constraint-sat strategy (complex):\n'
485
+ f' wrap_softmax ={self._wrap_softmax}\n'
486
+ f' use_new_projection ={self._use_new_projection}\n'
487
+ f' max_projection_iters={self._max_constraint_iter}\n')
482
488
 
483
489
  def compile(self, compiled: JaxRDDLCompilerWithGrad,
484
490
  _bounds: Bounds,
@@ -531,7 +537,7 @@ class JaxStraightLinePlan(JaxPlan):
531
537
  def _jax_non_bool_param_to_action(var, param, hyperparams):
532
538
  if wrap_non_bool:
533
539
  lower, upper = bounds_safe[var]
534
- mb, ml, mu, mn = [mask.astype(compiled.REAL)
540
+ mb, ml, mu, mn = [jnp.asarray(mask, dtype=compiled.REAL)
535
541
  for mask in cond_lists[var]]
536
542
  action = (
537
543
  mb * (lower + (upper - lower) * jax.nn.sigmoid(param)) +
@@ -616,7 +622,7 @@ class JaxStraightLinePlan(JaxPlan):
616
622
  action = _jax_non_bool_param_to_action(var, action, hyperparams)
617
623
  action = jnp.clip(action, *bounds[var])
618
624
  if ranges[var] == 'int':
619
- action = jnp.round(action).astype(compiled.INT)
625
+ action = jnp.asarray(jnp.round(action), dtype=compiled.INT)
620
626
  actions[var] = action
621
627
  return actions
622
628
 
@@ -856,15 +862,16 @@ class JaxDeepReactivePolicy(JaxPlan):
856
862
  bounds = '\n '.join(
857
863
  map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
858
864
  return (f'policy hyper-parameters:\n'
859
- f' topology ={self._topology}\n'
860
- f' activation_fn ={self._activations[0].__name__}\n'
861
- f' initializer ={type(self._initializer_base).__name__}\n'
862
- f' apply_input_norm ={self._normalize}\n'
863
- f' input_norm_layerwise={self._normalize_per_layer}\n'
864
- f' input_norm_args ={self._normalizer_kwargs}\n'
865
- f'constraint-sat strategy:\n'
866
- f' parsed_action_bounds=\n {bounds}\n'
867
- f' wrap_non_bool ={self._wrap_non_bool}')
865
+ f' topology ={self._topology}\n'
866
+ f' activation_fn={self._activations[0].__name__}\n'
867
+ f' initializer ={type(self._initializer_base).__name__}\n'
868
+ f' input norm:\n'
869
+ f' apply_input_norm ={self._normalize}\n'
870
+ f' input_norm_layerwise={self._normalize_per_layer}\n'
871
+ f' input_norm_args ={self._normalizer_kwargs}\n'
872
+ f' constraint-sat strategy:\n'
873
+ f' parsed_action_bounds=\n {bounds}\n'
874
+ f' wrap_non_bool ={self._wrap_non_bool}\n')
868
875
 
869
876
  def compile(self, compiled: JaxRDDLCompilerWithGrad,
870
877
  _bounds: Bounds,
@@ -916,12 +923,11 @@ class JaxDeepReactivePolicy(JaxPlan):
916
923
  non_bool_dims = 0
917
924
  for (var, values) in observed_vars.items():
918
925
  if ranges[var] != 'bool':
919
- value_size = np.atleast_1d(values).size
926
+ value_size = np.size(values)
920
927
  if normalize_per_layer and value_size == 1:
921
928
  raise_warning(
922
929
  f'Cannot apply layer norm to state-fluent <{var}> '
923
- f'of size 1: setting normalize_per_layer = False.',
924
- 'red')
930
+ f'of size 1: setting normalize_per_layer = False.', 'red')
925
931
  normalize_per_layer = False
926
932
  non_bool_dims += value_size
927
933
  if not normalize_per_layer and non_bool_dims == 1:
@@ -945,9 +951,11 @@ class JaxDeepReactivePolicy(JaxPlan):
945
951
  else:
946
952
  if normalize and normalize_per_layer:
947
953
  normalizer = hk.LayerNorm(
948
- axis=-1, param_axis=-1,
954
+ axis=-1,
955
+ param_axis=-1,
949
956
  name=f'input_norm_{input_names[var]}',
950
- **self._normalizer_kwargs)
957
+ **self._normalizer_kwargs
958
+ )
951
959
  state = normalizer(state)
952
960
  states_non_bool.append(state)
953
961
  non_bool_dims += state.size
@@ -956,8 +964,11 @@ class JaxDeepReactivePolicy(JaxPlan):
956
964
  # optionally perform layer normalization on the non-bool inputs
957
965
  if normalize and not normalize_per_layer and non_bool_dims:
958
966
  normalizer = hk.LayerNorm(
959
- axis=-1, param_axis=-1, name='input_norm',
960
- **self._normalizer_kwargs)
967
+ axis=-1,
968
+ param_axis=-1,
969
+ name='input_norm',
970
+ **self._normalizer_kwargs
971
+ )
961
972
  normalized = normalizer(state[:non_bool_dims])
962
973
  state = state.at[:non_bool_dims].set(normalized)
963
974
  return state
@@ -976,7 +987,8 @@ class JaxDeepReactivePolicy(JaxPlan):
976
987
  actions = {}
977
988
  for (var, size) in layer_sizes.items():
978
989
  linear = hk.Linear(size, name=layer_names[var], w_init=init)
979
- reshape = hk.Reshape(output_shape=shapes[var], preserve_dims=-1,
990
+ reshape = hk.Reshape(output_shape=shapes[var],
991
+ preserve_dims=-1,
980
992
  name=f'reshape_{layer_names[var]}')
981
993
  output = reshape(linear(hidden))
982
994
  if not shapes[var]:
@@ -989,7 +1001,7 @@ class JaxDeepReactivePolicy(JaxPlan):
989
1001
  else:
990
1002
  if wrap_non_bool:
991
1003
  lower, upper = bounds_safe[var]
992
- mb, ml, mu, mn = [mask.astype(compiled.REAL)
1004
+ mb, ml, mu, mn = [jnp.asarray(mask, dtype=compiled.REAL)
993
1005
  for mask in cond_lists[var]]
994
1006
  action = (
995
1007
  mb * (lower + (upper - lower) * jax.nn.sigmoid(output)) +
@@ -1003,8 +1015,7 @@ class JaxDeepReactivePolicy(JaxPlan):
1003
1015
 
1004
1016
  # for constraint satisfaction wrap bool actions with softmax
1005
1017
  if use_constraint_satisfaction:
1006
- linear = hk.Linear(
1007
- bool_action_count, name='output_bool', w_init=init)
1018
+ linear = hk.Linear(bool_action_count, name='output_bool', w_init=init)
1008
1019
  output = jax.nn.softmax(linear(hidden))
1009
1020
  actions[bool_key] = output
1010
1021
 
@@ -1042,8 +1053,7 @@ class JaxDeepReactivePolicy(JaxPlan):
1042
1053
 
1043
1054
  # test action prediction
1044
1055
  def _jax_wrapped_drp_predict_test(key, params, hyperparams, step, subs):
1045
- actions = _jax_wrapped_drp_predict_train(
1046
- key, params, hyperparams, step, subs)
1056
+ actions = _jax_wrapped_drp_predict_train(key, params, hyperparams, step, subs)
1047
1057
  new_actions = {}
1048
1058
  for (var, action) in actions.items():
1049
1059
  prange = ranges[var]
@@ -1051,7 +1061,7 @@ class JaxDeepReactivePolicy(JaxPlan):
1051
1061
  new_action = action > 0.5
1052
1062
  elif prange == 'int':
1053
1063
  action = jnp.clip(action, *bounds[var])
1054
- new_action = jnp.round(action).astype(compiled.INT)
1064
+ new_action = jnp.asarray(jnp.round(action), dtype=compiled.INT)
1055
1065
  else:
1056
1066
  new_action = jnp.clip(action, *bounds[var])
1057
1067
  new_actions[var] = new_action
@@ -1090,10 +1100,11 @@ class JaxDeepReactivePolicy(JaxPlan):
1090
1100
 
1091
1101
 
1092
1102
  # ***********************************************************************
1093
- # ALL VERSIONS OF JAX PLANNER
1103
+ # SUPPORTING FUNCTIONS
1094
1104
  #
1095
- # - simple gradient descent based planner
1096
- # - more stable but slower line search based planner
1105
+ # - smoothed mean calculation
1106
+ # - planner status
1107
+ # - stopping criteria
1097
1108
  #
1098
1109
  # ***********************************************************************
1099
1110
 
@@ -1167,6 +1178,329 @@ class NoImprovementStoppingRule(JaxPlannerStoppingRule):
1167
1178
  return f'No improvement for {self.patience} iterations'
1168
1179
 
1169
1180
 
1181
+ # ***********************************************************************
1182
+ # PARAMETER EXPLORING POLICY GRADIENTS (PGPE)
1183
+ #
1184
+ # - simple Gaussian PGPE
1185
+ #
1186
+ # ***********************************************************************
1187
+
1188
+
1189
+ class PGPE:
1190
+ """Base class for all PGPE strategies."""
1191
+
1192
+ def __init__(self) -> None:
1193
+ self._initializer = None
1194
+ self._update = None
1195
+
1196
+ @property
1197
+ def initialize(self):
1198
+ return self._initializer
1199
+
1200
+ @property
1201
+ def update(self):
1202
+ return self._update
1203
+
1204
+ def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type) -> None:
1205
+ raise NotImplementedError
1206
+
1207
+
1208
+ class GaussianPGPE(PGPE):
1209
+ '''PGPE with a Gaussian parameter distribution.'''
1210
+
1211
+ def __init__(self, batch_size: int=1,
1212
+ init_sigma: float=1.0,
1213
+ sigma_range: Tuple[float, float]=(1e-5, 1e5),
1214
+ scale_reward: bool=True,
1215
+ super_symmetric: bool=True,
1216
+ super_symmetric_accurate: bool=True,
1217
+ optimizer: Callable[..., optax.GradientTransformation]=optax.adam,
1218
+ optimizer_kwargs_mu: Optional[Kwargs]=None,
1219
+ optimizer_kwargs_sigma: Optional[Kwargs]=None) -> None:
1220
+ '''Creates a new Gaussian PGPE planner.
1221
+
1222
+ :param batch_size: how many policy parameters to sample per optimization step
1223
+ :param init_sigma: initial standard deviation of Gaussian
1224
+ :param sigma_range: bounds to constrain standard deviation
1225
+ :param scale_reward: whether to apply reward scaling as in the paper
1226
+ :param super_symmetric: whether to use super-symmetric sampling as in the paper
1227
+ :param super_symmetric_accurate: whether to use the accurate formula for super-
1228
+ symmetric sampling or the simplified but biased formula
1229
+ :param optimizer: a factory for an optax SGD algorithm
1230
+ :param optimizer_kwargs_mu: a dictionary of parameters to pass to the SGD
1231
+ factory for the mean optimizer
1232
+ :param optimizer_kwargs_sigma: a dictionary of parameters to pass to the SGD
1233
+ factory for the standard deviation optimizer
1234
+ '''
1235
+ super().__init__()
1236
+
1237
+ self.batch_size = batch_size
1238
+ self.init_sigma = init_sigma
1239
+ self.sigma_range = sigma_range
1240
+ self.scale_reward = scale_reward
1241
+ self.super_symmetric = super_symmetric
1242
+ self.super_symmetric_accurate = super_symmetric_accurate
1243
+
1244
+ # set optimizers
1245
+ if optimizer_kwargs_mu is None:
1246
+ optimizer_kwargs_mu = {'learning_rate': 0.1}
1247
+ self.optimizer_kwargs_mu = optimizer_kwargs_mu
1248
+ if optimizer_kwargs_sigma is None:
1249
+ optimizer_kwargs_sigma = {'learning_rate': 0.1}
1250
+ self.optimizer_kwargs_sigma = optimizer_kwargs_sigma
1251
+ self.optimizer_name = optimizer
1252
+ mu_optimizer = optimizer(**optimizer_kwargs_mu)
1253
+ sigma_optimizer = optimizer(**optimizer_kwargs_sigma)
1254
+ self.optimizers = (mu_optimizer, sigma_optimizer)
1255
+
1256
+ def __str__(self) -> str:
1257
+ return (f'PGPE hyper-parameters:\n'
1258
+ f' method ={self.__class__.__name__}\n'
1259
+ f' batch_size ={self.batch_size}\n'
1260
+ f' init_sigma ={self.init_sigma}\n'
1261
+ f' sigma_range ={self.sigma_range}\n'
1262
+ f' scale_reward ={self.scale_reward}\n'
1263
+ f' super_symmetric={self.super_symmetric}\n'
1264
+ f' accurate ={self.super_symmetric_accurate}\n'
1265
+ f' optimizer ={self.optimizer_name}\n'
1266
+ f' optimizer_kwargs:\n'
1267
+ f' mu ={self.optimizer_kwargs_mu}\n'
1268
+ f' sigma={self.optimizer_kwargs_sigma}\n'
1269
+ )
1270
+
1271
+ def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type) -> None:
1272
+ MIN_NORM = 1e-5
1273
+ sigma0 = self.init_sigma
1274
+ sigma_range = self.sigma_range
1275
+ scale_reward = self.scale_reward
1276
+ super_symmetric = self.super_symmetric
1277
+ super_symmetric_accurate = self.super_symmetric_accurate
1278
+ batch_size = self.batch_size
1279
+ optimizers = (mu_optimizer, sigma_optimizer) = self.optimizers
1280
+
1281
+ # initializer
1282
+ def _jax_wrapped_pgpe_init(key, policy_params):
1283
+ mu = policy_params
1284
+ sigma = jax.tree_map(lambda x: sigma0 * jnp.ones_like(x), mu)
1285
+ pgpe_params = (mu, sigma)
1286
+ pgpe_opt_state = tuple(opt.init(param)
1287
+ for (opt, param) in zip(optimizers, pgpe_params))
1288
+ return pgpe_params, pgpe_opt_state
1289
+
1290
+ self._initializer = jax.jit(_jax_wrapped_pgpe_init)
1291
+
1292
+ # parameter sampling functions
1293
+ def _jax_wrapped_mu_noise(key, sigma):
1294
+ return sigma * random.normal(key, shape=jnp.shape(sigma), dtype=real_dtype)
1295
+
1296
+ def _jax_wrapped_epsilon_star(sigma, epsilon):
1297
+ c1, c2, c3 = -0.06655, -0.9706, 0.124
1298
+ phi = 0.67449 * sigma
1299
+ a = (sigma - jnp.abs(epsilon)) / sigma
1300
+ if super_symmetric_accurate:
1301
+ aa = jnp.abs(a)
1302
+ epsilon_star = jnp.sign(epsilon) * phi * jnp.where(
1303
+ a <= 0,
1304
+ jnp.exp(c1 * aa * (aa * aa - 1) / jnp.log(aa + 1e-10) + c2 * aa),
1305
+ jnp.exp(aa - c3 * aa * jnp.log(1.0 - jnp.power(aa, 3) + 1e-10))
1306
+ )
1307
+ else:
1308
+ epsilon_star = jnp.sign(epsilon) * phi * jnp.exp(a)
1309
+ return epsilon_star
1310
+
1311
+ def _jax_wrapped_sample_params(key, mu, sigma):
1312
+ keys = random.split(key, num=len(jax.tree_util.tree_leaves(mu)))
1313
+ keys_pytree = jax.tree_util.tree_unflatten(
1314
+ treedef=jax.tree_util.tree_structure(mu), leaves=keys)
1315
+ epsilon = jax.tree_map(_jax_wrapped_mu_noise, keys_pytree, sigma)
1316
+ p1 = jax.tree_map(jnp.add, mu, epsilon)
1317
+ p2 = jax.tree_map(jnp.subtract, mu, epsilon)
1318
+ if super_symmetric:
1319
+ epsilon_star = jax.tree_map(_jax_wrapped_epsilon_star, sigma, epsilon)
1320
+ p3 = jax.tree_map(jnp.add, mu, epsilon_star)
1321
+ p4 = jax.tree_map(jnp.subtract, mu, epsilon_star)
1322
+ else:
1323
+ epsilon_star, p3, p4 = epsilon, p1, p2
1324
+ return (p1, p2, p3, p4), (epsilon, epsilon_star)
1325
+
1326
+ # policy gradient update functions
1327
+ def _jax_wrapped_mu_grad(epsilon, epsilon_star, r1, r2, r3, r4, m):
1328
+ if super_symmetric:
1329
+ if scale_reward:
1330
+ scale1 = jnp.maximum(MIN_NORM, m - (r1 + r2) / 2)
1331
+ scale2 = jnp.maximum(MIN_NORM, m - (r3 + r4) / 2)
1332
+ else:
1333
+ scale1 = scale2 = 1.0
1334
+ r_mu1 = (r1 - r2) / (2 * scale1)
1335
+ r_mu2 = (r3 - r4) / (2 * scale2)
1336
+ grad = -(r_mu1 * epsilon + r_mu2 * epsilon_star)
1337
+ else:
1338
+ if scale_reward:
1339
+ scale = jnp.maximum(MIN_NORM, m - (r1 + r2) / 2)
1340
+ else:
1341
+ scale = 1.0
1342
+ r_mu = (r1 - r2) / (2 * scale)
1343
+ grad = -r_mu * epsilon
1344
+ return grad
1345
+
1346
+ def _jax_wrapped_sigma_grad(epsilon, epsilon_star, sigma, r1, r2, r3, r4, m):
1347
+ if super_symmetric:
1348
+ mask = r1 + r2 >= r3 + r4
1349
+ epsilon_tau = mask * epsilon + (1 - mask) * epsilon_star
1350
+ s = epsilon_tau * epsilon_tau / sigma - sigma
1351
+ if scale_reward:
1352
+ scale = jnp.maximum(MIN_NORM, m - (r1 + r2 + r3 + r4) / 4)
1353
+ else:
1354
+ scale = 1.0
1355
+ r_sigma = ((r1 + r2) - (r3 + r4)) / (4 * scale)
1356
+ else:
1357
+ s = epsilon * epsilon / sigma - sigma
1358
+ if scale_reward:
1359
+ scale = jnp.maximum(MIN_NORM, jnp.abs(m))
1360
+ else:
1361
+ scale = 1.0
1362
+ r_sigma = (r1 + r2) / (2 * scale)
1363
+ grad = -r_sigma * s
1364
+ return grad
1365
+
1366
+ def _jax_wrapped_pgpe_grad(key, mu, sigma, r_max,
1367
+ policy_hyperparams, subs, model_params):
1368
+ key, subkey = random.split(key)
1369
+ (p1, p2, p3, p4), (epsilon, epsilon_star) = _jax_wrapped_sample_params(
1370
+ key, mu, sigma)
1371
+ r1 = -loss_fn(subkey, p1, policy_hyperparams, subs, model_params)[0]
1372
+ r2 = -loss_fn(subkey, p2, policy_hyperparams, subs, model_params)[0]
1373
+ r_max = jnp.maximum(r_max, r1)
1374
+ r_max = jnp.maximum(r_max, r2)
1375
+ if super_symmetric:
1376
+ r3 = -loss_fn(subkey, p3, policy_hyperparams, subs, model_params)[0]
1377
+ r4 = -loss_fn(subkey, p4, policy_hyperparams, subs, model_params)[0]
1378
+ r_max = jnp.maximum(r_max, r3)
1379
+ r_max = jnp.maximum(r_max, r4)
1380
+ else:
1381
+ r3, r4 = r1, r2
1382
+ grad_mu = jax.tree_map(
1383
+ partial(_jax_wrapped_mu_grad, r1=r1, r2=r2, r3=r3, r4=r4, m=r_max),
1384
+ epsilon, epsilon_star
1385
+ )
1386
+ grad_sigma = jax.tree_map(
1387
+ partial(_jax_wrapped_sigma_grad, r1=r1, r2=r2, r3=r3, r4=r4, m=r_max),
1388
+ epsilon, epsilon_star, sigma
1389
+ )
1390
+ return grad_mu, grad_sigma, r_max
1391
+
1392
+ def _jax_wrapped_pgpe_grad_batched(key, pgpe_params, r_max,
1393
+ policy_hyperparams, subs, model_params):
1394
+ mu, sigma = pgpe_params
1395
+ if batch_size == 1:
1396
+ mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad(
1397
+ key, mu, sigma, r_max, policy_hyperparams, subs, model_params)
1398
+ else:
1399
+ keys = random.split(key, num=batch_size)
1400
+ mu_grads, sigma_grads, r_maxs = jax.vmap(
1401
+ _jax_wrapped_pgpe_grad,
1402
+ in_axes=(0, None, None, None, None, None, None)
1403
+ )(keys, mu, sigma, r_max, policy_hyperparams, subs, model_params)
1404
+ mu_grad, sigma_grad = jax.tree_map(
1405
+ partial(jnp.mean, axis=0), (mu_grads, sigma_grads))
1406
+ new_r_max = jnp.max(r_maxs)
1407
+ return mu_grad, sigma_grad, new_r_max
1408
+
1409
+ def _jax_wrapped_pgpe_update(key, pgpe_params, r_max,
1410
+ policy_hyperparams, subs, model_params,
1411
+ pgpe_opt_state):
1412
+ mu, sigma = pgpe_params
1413
+ mu_state, sigma_state = pgpe_opt_state
1414
+ mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad_batched(
1415
+ key, pgpe_params, r_max, policy_hyperparams, subs, model_params)
1416
+ mu_updates, new_mu_state = mu_optimizer.update(mu_grad, mu_state, params=mu)
1417
+ sigma_updates, new_sigma_state = sigma_optimizer.update(
1418
+ sigma_grad, sigma_state, params=sigma)
1419
+ new_mu = optax.apply_updates(mu, mu_updates)
1420
+ new_mu, converged = projection(new_mu, policy_hyperparams)
1421
+ new_sigma = optax.apply_updates(sigma, sigma_updates)
1422
+ new_sigma = jax.tree_map(lambda x: jnp.clip(x, *sigma_range), new_sigma)
1423
+ new_pgpe_params = (new_mu, new_sigma)
1424
+ new_pgpe_opt_state = (new_mu_state, new_sigma_state)
1425
+ policy_params = new_mu
1426
+ return new_pgpe_params, new_r_max, new_pgpe_opt_state, policy_params, converged
1427
+
1428
+ self._update = jax.jit(_jax_wrapped_pgpe_update)
1429
+
1430
+
1431
+ # ***********************************************************************
1432
+ # ALL VERSIONS OF RISK FUNCTIONS
1433
+ #
1434
+ # Based on the original paper "A Distributional Framework for Risk-Sensitive
1435
+ # End-to-End Planning in Continuous MDPs" by Patton et al., AAAI 2022.
1436
+ #
1437
+ # Original risk functions:
1438
+ # - entropic utility
1439
+ # - mean-variance
1440
+ # - mean-semideviation
1441
+ # - conditional value at risk with straight-through gradient trick
1442
+ #
1443
+ # ***********************************************************************
1444
+
1445
+
1446
+ @jax.jit
1447
+ def entropic_utility(returns: jnp.ndarray, beta: float) -> float:
1448
+ return (-1.0 / beta) * jax.scipy.special.logsumexp(
1449
+ -beta * returns, b=1.0 / returns.size)
1450
+
1451
+
1452
+ @jax.jit
1453
+ def mean_variance_utility(returns: jnp.ndarray, beta: float) -> float:
1454
+ return jnp.mean(returns) - 0.5 * beta * jnp.var(returns)
1455
+
1456
+
1457
+ @jax.jit
1458
+ def mean_deviation_utility(returns: jnp.ndarray, beta: float) -> float:
1459
+ return jnp.mean(returns) - 0.5 * beta * jnp.std(returns)
1460
+
1461
+
1462
+ @jax.jit
1463
+ def mean_semideviation_utility(returns: jnp.ndarray, beta: float) -> float:
1464
+ mu = jnp.mean(returns)
1465
+ msd = jnp.sqrt(jnp.mean(jnp.minimum(0.0, returns - mu) ** 2))
1466
+ return mu - 0.5 * beta * msd
1467
+
1468
+
1469
+ @jax.jit
1470
+ def mean_semivariance_utility(returns: jnp.ndarray, beta: float) -> float:
1471
+ mu = jnp.mean(returns)
1472
+ msv = jnp.mean(jnp.minimum(0.0, returns - mu) ** 2)
1473
+ return mu - 0.5 * beta * msv
1474
+
1475
+
1476
+ @jax.jit
1477
+ def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
1478
+ var = jnp.percentile(returns, q=100 * alpha)
1479
+ mask = returns <= var
1480
+ weights = mask / jnp.maximum(1, jnp.sum(mask))
1481
+ return jnp.sum(returns * weights)
1482
+
1483
+
1484
+ UTILITY_LOOKUP = {
1485
+ 'mean': jnp.mean,
1486
+ 'mean_var': mean_variance_utility,
1487
+ 'mean_std': mean_deviation_utility,
1488
+ 'mean_semivar': mean_semivariance_utility,
1489
+ 'mean_semidev': mean_semideviation_utility,
1490
+ 'entropic': entropic_utility,
1491
+ 'exponential': entropic_utility,
1492
+ 'cvar': cvar_utility
1493
+ }
1494
+
1495
+
1496
+ # ***********************************************************************
1497
+ # ALL VERSIONS OF JAX PLANNER
1498
+ #
1499
+ # - simple gradient descent based planner
1500
+ #
1501
+ # ***********************************************************************
1502
+
1503
+
1170
1504
  class JaxBackpropPlanner:
1171
1505
  '''A class for optimizing an action sequence in the given RDDL MDP using
1172
1506
  gradient descent.'''
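
The PGPE section added above maintains a Gaussian search distribution (mu, sigma) over the policy parameters and estimates gradients from mirrored parameter samples, following Sehnke and Zhao [5]. Below is a minimal sketch of the underlying estimator for a single sample, corresponding to the scale_reward=False branch of the code above; reward scaling, the running r_max baseline, super-symmetric sampling, and the optax optimizers are omitted, and returns_fn stands in for the negated loss function passed to compile:

    import jax.numpy as jnp
    from jax import random

    def pgpe_step(key, mu, sigma, returns_fn, lr=0.1):
        # perturb the mean: epsilon ~ N(0, sigma^2), evaluate at mu + eps and mu - eps
        eps = sigma * random.normal(key, jnp.shape(mu))
        r_plus, r_minus = returns_fn(mu + eps), returns_fn(mu - eps)
        # ascent directions for the mean and standard deviation of the search distribution
        grad_mu = 0.5 * (r_plus - r_minus) * eps
        grad_sigma = 0.5 * (r_plus + r_minus) * (eps ** 2 - sigma ** 2) / sigma
        return mu + lr * grad_mu, sigma + lr * grad_sigma
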
@@ -1183,6 +1517,7 @@ class JaxBackpropPlanner:
1183
1517
  clip_grad: Optional[float]=None,
1184
1518
  line_search_kwargs: Optional[Kwargs]=None,
1185
1519
  noise_kwargs: Optional[Kwargs]=None,
1520
+ pgpe: Optional[PGPE]=GaussianPGPE(),
1186
1521
  logic: Logic=FuzzyLogic(),
1187
1522
  use_symlog_reward: bool=False,
1188
1523
  utility: Union[Callable[[jnp.ndarray], float], str]='mean',
@@ -1213,14 +1548,14 @@ class JaxBackpropPlanner:
1213
1548
  :param line_search_kwargs: parameters to pass to optional line search
1214
1549
  method to scale learning rate
1215
1550
  :param noise_kwargs: parameters of optional gradient noise
1551
+ :param pgpe: optional policy gradient to run alongside the planner
1216
1552
  :param logic: a subclass of Logic for mapping exact mathematical
1217
1553
  operations to their differentiable counterparts
1218
1554
  :param use_symlog_reward: whether to use the symlog transform on the
1219
1555
  reward as a form of normalization
1220
1556
  :param utility: how to aggregate return observations to compute utility
1221
1557
  of a policy or plan; must be either a function mapping jax array to a
1222
- scalar, or a a string identifying the utility function by name
1223
- ("mean", "mean_var", "entropic", or "cvar" are currently supported)
1558
+ scalar, or a a string identifying the utility function by name
1224
1559
  :param utility_kwargs: additional keyword arguments to pass hyper-
1225
1560
  parameters to the utility function call
1226
1561
  :param cpfs_without_grad: which CPFs do not have gradients (use straight
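
The new pgpe argument documented above defaults to an enabled GaussianPGPE and can be disabled with pgpe=None. A hedged construction sketch follows; the pgpe, optimizer, utility, and utility_kwargs keywords are taken from this diff, while the model handle, the remaining keyword names, and the optimize() call are illustrative assumptions:

    import optax
    from jax import random

    planner = JaxBackpropPlanner(
        rddl=env.model,                      # assumed: a vectorized pyRDDLGym model
        plan=JaxStraightLinePlan(),
        optimizer=optax.rmsprop,
        utility='cvar', utility_kwargs={'alpha': 0.05},
        pgpe=GaussianPGPE(batch_size=1, init_sigma=1.0))   # or pgpe=None to disable
    callback = planner.optimize(key=random.PRNGKey(42), epochs=1000)
    best_params = callback['best_params']    # assumed: optimize() returns the final callback
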
@@ -1251,6 +1586,8 @@ class JaxBackpropPlanner:
1251
1586
  self.clip_grad = clip_grad
1252
1587
  self.line_search_kwargs = line_search_kwargs
1253
1588
  self.noise_kwargs = noise_kwargs
1589
+ self.pgpe = pgpe
1590
+ self.use_pgpe = pgpe is not None
1254
1591
 
1255
1592
  # set optimizer
1256
1593
  try:
@@ -1276,18 +1613,11 @@ class JaxBackpropPlanner:
1276
1613
  # set utility
1277
1614
  if isinstance(utility, str):
1278
1615
  utility = utility.lower()
1279
- if utility == 'mean':
1280
- utility_fn = jnp.mean
1281
- elif utility == 'mean_var':
1282
- utility_fn = mean_variance_utility
1283
- elif utility == 'entropic':
1284
- utility_fn = entropic_utility
1285
- elif utility == 'cvar':
1286
- utility_fn = cvar_utility
1287
- else:
1616
+ utility_fn = UTILITY_LOOKUP.get(utility, None)
1617
+ if utility_fn is None:
1288
1618
  raise RDDLNotImplementedError(
1289
- f'Utility function <{utility}> is not supported: '
1290
- 'must be one of ["mean", "mean_var", "entropic", "cvar"].')
1619
+ f'Utility <{utility}> is not supported, '
1620
+ f'must be one of {list(UTILITY_LOOKUP.keys())}.')
1291
1621
  else:
1292
1622
  utility_fn = utility
1293
1623
  self.utility = utility_fn
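
Utility selection now dispatches through the module-level UTILITY_LOOKUP table defined earlier in this diff. A small worked illustration of two of the registered utilities (the numbers follow directly from the formulas above):

    import jax.numpy as jnp

    returns = jnp.array([1.0, 2.0, 3.0, 4.0])
    cvar_utility(returns, 0.5)             # 1.5: mean of the worst 50% of returns
    mean_variance_utility(returns, 1.0)    # 1.875 = 2.5 (mean) - 0.5 * 1.25 (variance)
    UTILITY_LOOKUP['exponential'] is entropic_utility   # True: 'exponential' aliases 'entropic'
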
@@ -1355,24 +1685,25 @@ r"""
1355
1685
  f' line_search_kwargs={self.line_search_kwargs}\n'
1356
1686
  f' noise_kwargs ={self.noise_kwargs}\n'
1357
1687
  f' batch_size_train ={self.batch_size_train}\n'
1358
- f' batch_size_test ={self.batch_size_test}')
1359
- result += '\n' + str(self.plan)
1360
- result += '\n' + str(self.logic)
1688
+ f' batch_size_test ={self.batch_size_test}\n')
1689
+ result += str(self.plan)
1690
+ if self.use_pgpe:
1691
+ result += str(self.pgpe)
1692
+ result += str(self.logic)
1361
1693
 
1362
1694
  # print model relaxation information
1363
- if not self.compiled.model_params:
1364
- return result
1365
- result += '\n' + ('Some RDDL operations are non-differentiable '
1366
- 'and will be approximated as follows:' + '\n')
1367
- exprs_by_rddl_op, values_by_rddl_op = {}, {}
1368
- for info in self.compiled.model_parameter_info().values():
1369
- rddl_op = info['rddl_op']
1370
- exprs_by_rddl_op.setdefault(rddl_op, []).append(info['id'])
1371
- values_by_rddl_op.setdefault(rddl_op, []).append(info['init_value'])
1372
- for rddl_op in sorted(exprs_by_rddl_op.keys()):
1373
- result += (f' {rddl_op}:\n'
1374
- f' addresses ={exprs_by_rddl_op[rddl_op]}\n'
1375
- f' init_values={values_by_rddl_op[rddl_op]}\n')
1695
+ if self.compiled.model_params:
1696
+ result += ('Some RDDL operations are non-differentiable '
1697
+ 'and will be approximated as follows:' + '\n')
1698
+ exprs_by_rddl_op, values_by_rddl_op = {}, {}
1699
+ for info in self.compiled.model_parameter_info().values():
1700
+ rddl_op = info['rddl_op']
1701
+ exprs_by_rddl_op.setdefault(rddl_op, []).append(info['id'])
1702
+ values_by_rddl_op.setdefault(rddl_op, []).append(info['init_value'])
1703
+ for rddl_op in sorted(exprs_by_rddl_op.keys()):
1704
+ result += (f' {rddl_op}:\n'
1705
+ f' addresses ={exprs_by_rddl_op[rddl_op]}\n'
1706
+ f' init_values={values_by_rddl_op[rddl_op]}\n')
1376
1707
  return result
1377
1708
 
1378
1709
  def summarize_hyperparameters(self) -> None:
@@ -1438,6 +1769,15 @@ r"""
1438
1769
  # optimization
1439
1770
  self.update = self._jax_update(train_loss)
1440
1771
  self.check_zero_grad = self._jax_check_zero_gradients()
1772
+
1773
+ # pgpe option
1774
+ if self.use_pgpe:
1775
+ loss_fn = self._jax_loss(rollouts=test_rollouts)
1776
+ self.pgpe.compile(
1777
+ loss_fn=loss_fn,
1778
+ projection=self.plan.projection,
1779
+ real_dtype=self.test_compiled.REAL
1780
+ )
1441
1781
 
1442
1782
  def _jax_return(self, use_symlog):
1443
1783
  gamma = self.rddl.discount
@@ -1547,7 +1887,7 @@ r"""
1547
1887
  f'{set(self.test_compiled.init_values.keys())}.')
1548
1888
  value = np.reshape(value, newshape=np.shape(init_value))[np.newaxis, ...]
1549
1889
  train_value = np.repeat(value, repeats=n_train, axis=0)
1550
- train_value = train_value.astype(self.compiled.REAL)
1890
+ train_value = np.asarray(train_value, dtype=self.compiled.REAL)
1551
1891
  init_train[name] = train_value
1552
1892
  init_test[name] = np.repeat(value, repeats=n_test, axis=0)
1553
1893
 
@@ -1646,7 +1986,7 @@ r"""
1646
1986
  return grad
1647
1987
 
1648
1988
  return _loss_function, _grad_function, guess_1d, jax.jit(unravel_fn)
1649
-
1989
+
1650
1990
  # ===========================================================================
1651
1991
  # OPTIMIZE API
1652
1992
  # ===========================================================================
@@ -1819,7 +2159,17 @@ r"""
1819
2159
  policy_params = guess
1820
2160
  opt_state = self.optimizer.init(policy_params)
1821
2161
  opt_aux = {}
1822
-
2162
+
2163
+ # initialize pgpe parameters
2164
+ if self.use_pgpe:
2165
+ pgpe_params, pgpe_opt_state = self.pgpe.initialize(key, policy_params)
2166
+ rolling_pgpe_loss = RollingMean(test_rolling_window)
2167
+ else:
2168
+ pgpe_params, pgpe_opt_state = None, None
2169
+ rolling_pgpe_loss = None
2170
+ total_pgpe_it = 0
2171
+ r_max = -jnp.inf
2172
+
1823
2173
  # ======================================================================
1824
2174
  # INITIALIZATION OF RUNNING STATISTICS
1825
2175
  # ======================================================================
@@ -1847,7 +2197,9 @@ r"""
1847
2197
 
1848
2198
  iters = range(epochs)
1849
2199
  if print_progress:
1850
- iters = tqdm(iters, total=100, position=tqdm_position)
2200
+ iters = tqdm(iters, total=100,
2201
+ bar_format='{l_bar}{bar}| {elapsed} {postfix}',
2202
+ position=tqdm_position)
1851
2203
  position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
1852
2204
 
1853
2205
  for it in iters:
@@ -1860,17 +2212,47 @@ r"""
1860
2212
 
1861
2213
  # update the parameters of the plan
1862
2214
  key, subkey = random.split(key)
1863
- (policy_params, converged, opt_state, opt_aux,
1864
- train_loss, train_log, model_params) = \
1865
- self.update(subkey, policy_params, policy_hyperparams,
1866
- train_subs, model_params, opt_state, opt_aux)
1867
-
2215
+ (policy_params, converged, opt_state, opt_aux, train_loss, train_log,
2216
+ model_params) = self.update(subkey, policy_params, policy_hyperparams,
2217
+ train_subs, model_params, opt_state, opt_aux)
2218
+ test_loss, (test_log, model_params_test) = self.test_loss(
2219
+ subkey, policy_params, policy_hyperparams, test_subs, model_params_test)
2220
+ test_loss_smooth = rolling_test_loss.update(test_loss)
2221
+
2222
+ # pgpe update of the plan
2223
+ pgpe_improve = False
2224
+ if self.use_pgpe:
2225
+ key, subkey = random.split(key)
2226
+ pgpe_params, r_max, pgpe_opt_state, pgpe_param, pgpe_converged = \
2227
+ self.pgpe.update(subkey, pgpe_params, r_max, policy_hyperparams,
2228
+ test_subs, model_params, pgpe_opt_state)
2229
+ pgpe_loss, _ = self.test_loss(
2230
+ subkey, pgpe_param, policy_hyperparams, test_subs, model_params_test)
2231
+ pgpe_loss_smooth = rolling_pgpe_loss.update(pgpe_loss)
2232
+ pgpe_return = -pgpe_loss_smooth
2233
+
2234
+ # replace with PGPE if it reaches a new minimum or train loss invalid
2235
+ if pgpe_loss_smooth < best_loss or not np.isfinite(train_loss):
2236
+ policy_params = pgpe_param
2237
+ test_loss, test_loss_smooth = pgpe_loss, pgpe_loss_smooth
2238
+ converged = pgpe_converged
2239
+ pgpe_improve = True
2240
+ total_pgpe_it += 1
2241
+ else:
2242
+ pgpe_loss, pgpe_loss_smooth, pgpe_return = None, None, None
2243
+
2244
+ # evaluate test losses and record best plan so far
2245
+ if test_loss_smooth < best_loss:
2246
+ best_params, best_loss, best_grad = \
2247
+ policy_params, test_loss_smooth, train_log['grad']
2248
+ last_iter_improve = it
2249
+
1868
2250
  # ==================================================================
1869
2251
  # STATUS CHECKS AND LOGGING
1870
2252
  # ==================================================================
1871
2253
 
1872
2254
  # no progress
1873
- if self.check_zero_grad(train_log['grad']):
2255
+ if (not pgpe_improve) and self.check_zero_grad(train_log['grad']):
1874
2256
  status = JaxPlannerStatus.NO_PROGRESS
1875
2257
 
1876
2258
  # constraint satisfaction problem
@@ -1882,21 +2264,14 @@ r"""
1882
2264
  status = JaxPlannerStatus.PRECONDITION_POSSIBLY_UNSATISFIED
1883
2265
 
1884
2266
  # numerical error
1885
- if not np.isfinite(train_loss):
1886
- raise_warning(
1887
- f'JAX planner aborted due to invalid loss {train_loss}.', 'red')
2267
+ if self.use_pgpe:
2268
+ invalid_loss = not (np.isfinite(train_loss) or np.isfinite(pgpe_loss))
2269
+ else:
2270
+ invalid_loss = not np.isfinite(train_loss)
2271
+ if invalid_loss:
2272
+ raise_warning(f'Planner aborted due to invalid loss {train_loss}.', 'red')
1888
2273
  status = JaxPlannerStatus.INVALID_GRADIENT
1889
2274
 
1890
- # evaluate test losses and record best plan so far
1891
- test_loss, (log, model_params_test) = self.test_loss(
1892
- subkey, policy_params, policy_hyperparams,
1893
- test_subs, model_params_test)
1894
- test_loss = rolling_test_loss.update(test_loss)
1895
- if test_loss < best_loss:
1896
- best_params, best_loss, best_grad = \
1897
- policy_params, test_loss, train_log['grad']
1898
- last_iter_improve = it
1899
-
1900
2275
  # reached computation budget
1901
2276
  elapsed = time.time() - start_time - elapsed_outside_loop
1902
2277
  if elapsed >= train_seconds:
@@ -1905,16 +2280,20 @@ r"""
1905
2280
  status = JaxPlannerStatus.ITER_BUDGET_REACHED
1906
2281
 
1907
2282
  # build a callback
1908
- progress_percent = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
2283
+ progress_percent = 100 * min(
2284
+ 1, max(0, elapsed / train_seconds, it / (epochs - 1)))
1909
2285
  callback = {
1910
2286
  'status': status,
1911
2287
  'iteration': it,
1912
2288
  'train_return':-train_loss,
1913
- 'test_return':-test_loss,
2289
+ 'test_return':-test_loss_smooth,
1914
2290
  'best_return':-best_loss,
2291
+ 'pgpe_return': pgpe_return,
1915
2292
  'params': policy_params,
1916
2293
  'best_params': best_params,
2294
+ 'pgpe_params': pgpe_params,
1917
2295
  'last_iteration_improved': last_iter_improve,
2296
+ 'pgpe_improved': pgpe_improve,
1918
2297
  'grad': train_log['grad'],
1919
2298
  'best_grad': best_grad,
1920
2299
  'updates': train_log['updates'],
@@ -1923,9 +2302,9 @@ r"""
1923
2302
  'model_params': model_params,
1924
2303
  'progress': progress_percent,
1925
2304
  'train_log': train_log,
1926
- **log
2305
+ **test_log
1927
2306
  }
1928
-
2307
+
1929
2308
  # stopping condition reached
1930
2309
  if stopping_rule is not None and stopping_rule.monitor(callback):
1931
2310
  callback['status'] = status = JaxPlannerStatus.STOPPING_RULE_REACHED
@@ -1934,10 +2313,12 @@ r"""
1934
2313
  if print_progress:
1935
2314
  iters.n = progress_percent
1936
2315
  iters.set_description(
1937
- f'{position_str} {it:6} it / {-train_loss:14.6f} train / '
1938
- f'{-test_loss:14.6f} test / {-best_loss:14.6f} best / '
1939
- f'{status.value} status'
2316
+ f'{position_str} {it:6} it / {-train_loss:14.5f} train / '
2317
+ f'{-test_loss_smooth:14.5f} test / {-best_loss:14.5f} best / '
2318
+ f'{status.value} status / {total_pgpe_it:6} pgpe',
2319
+ refresh=False
1940
2320
  )
2321
+ iters.set_postfix_str(f"{(it + 1) / elapsed:.2f}it/s", refresh=True)
1941
2322
 
1942
2323
  # dash-board
1943
2324
  if dashboard is not None:
@@ -1955,7 +2336,7 @@ r"""
1955
2336
  # ======================================================================
1956
2337
  # POST-PROCESSING AND CLEANUP
1957
2338
  # ======================================================================
1958
-
2339
+
1959
2340
  # release resources
1960
2341
  if print_progress:
1961
2342
  iters.close()
@@ -1967,7 +2348,7 @@ r"""
1967
2348
  messages.update(JaxRDDLCompiler.get_error_messages(error_code))
1968
2349
  if messages:
1969
2350
  messages = '\n'.join(messages)
1970
- raise_warning('The JAX compiler encountered the following '
2351
+ raise_warning('JAX compiler encountered the following '
1971
2352
  'error(s) in the original RDDL formulation '
1972
2353
  f'during test evaluation:\n{messages}', 'red')
1973
2354
 
@@ -1975,14 +2356,14 @@ r"""
1975
2356
  if print_summary:
1976
2357
  grad_norm = jax.tree_map(lambda x: np.linalg.norm(x).item(), best_grad)
1977
2358
  diagnosis = self._perform_diagnosis(
1978
- last_iter_improve, -train_loss, -test_loss, -best_loss, grad_norm)
2359
+ last_iter_improve, -train_loss, -test_loss_smooth, -best_loss, grad_norm)
1979
2360
  print(f'summary of optimization:\n'
1980
- f' status_code ={status}\n'
1981
- f' time_elapsed ={elapsed}\n'
2361
+ f' status ={status}\n'
2362
+ f' time ={elapsed:.3f} sec.\n'
1982
2363
  f' iterations ={it}\n'
1983
- f' best_objective={-best_loss}\n'
1984
- f' best_grad_norm={grad_norm}\n'
1985
- f' diagnosis: {diagnosis}\n')
2364
+ f' best objective={-best_loss:.6f}\n'
2365
+ f' best grad norm={grad_norm}\n'
2366
+ f'diagnosis: {diagnosis}\n')
1986
2367
 
1987
2368
  def _perform_diagnosis(self, last_iter_improve,
1988
2369
  train_return, test_return, best_return, grad_norm):
@@ -2002,23 +2383,24 @@ r"""
2002
2383
  if last_iter_improve <= 1:
2003
2384
  if grad_is_zero:
2004
2385
  return termcolor.colored(
2005
- '[FAILURE] no progress was made, '
2006
- f'and max grad norm {max_grad_norm:.6f} is zero: '
2386
+ '[FAILURE] no progress was made '
2387
+ f'and max grad norm {max_grad_norm:.6f} was zero: '
2007
2388
  'solver likely stuck in a plateau.', 'red')
2008
2389
  else:
2009
2390
  return termcolor.colored(
2010
- '[FAILURE] no progress was made, '
2011
- f'but max grad norm {max_grad_norm:.6f} is non-zero: '
2012
- 'likely poor learning rate or other hyper-parameter.', 'red')
2391
+ '[FAILURE] no progress was made '
2392
+ f'but max grad norm {max_grad_norm:.6f} was non-zero: '
2393
+ 'learning rate or other hyper-parameters likely suboptimal.',
2394
+ 'red')
2013
2395
 
2014
2396
  # model is likely poor IF:
2015
2397
  # 1. the train and test return disagree
2016
2398
  if not (validation_error < 20):
2017
2399
  return termcolor.colored(
2018
- '[WARNING] progress was made, '
2019
- f'but relative train-test error {validation_error:.6f} is high: '
2020
- 'likely poor model relaxation around the solution, '
2021
- 'or the batch size is too small.', 'yellow')
2400
+ '[WARNING] progress was made '
2401
+ f'but relative train-test error {validation_error:.6f} was high: '
2402
+ 'poor model relaxation around solution or batch size too small.',
2403
+ 'yellow')
2022
2404
 
2023
2405
  # model likely did not converge IF:
2024
2406
  # 1. the max grad relative to the return is high
@@ -2026,15 +2408,15 @@ r"""
2026
2408
  return_to_grad_norm = abs(best_return) / max_grad_norm
2027
2409
  if not (return_to_grad_norm > 1):
2028
2410
  return termcolor.colored(
2029
- '[WARNING] progress was made, '
2030
- f'but max grad norm {max_grad_norm:.6f} is high: '
2031
- 'likely the solution is not locally optimal, '
2032
- 'or the relaxed model is not smooth around the solution, '
2033
- 'or the batch size is too small.', 'yellow')
2411
+ '[WARNING] progress was made '
2412
+ f'but max grad norm {max_grad_norm:.6f} was high: '
2413
+ 'solution locally suboptimal '
2414
+ 'or relaxed model not smooth around solution '
2415
+ 'or batch size too small.', 'yellow')
2034
2416
 
2035
2417
  # likely successful
2036
2418
  return termcolor.colored(
2037
- '[SUCCESS] planner has converged successfully '
2419
+ '[SUCCESS] solver converged successfully '
2038
2420
  '(note: not all potential problems can be ruled out).', 'green')
2039
2421
 
2040
2422
  def get_action(self, key: random.PRNGKey,
@@ -2057,8 +2439,7 @@ r"""
2057
2439
  for (var, values) in subs.items():
2058
2440
 
2059
2441
  # must not be grounded
2060
- if RDDLPlanningModel.FLUENT_SEP in var \
2061
- or RDDLPlanningModel.OBJECT_SEP in var:
2442
+ if RDDLPlanningModel.FLUENT_SEP in var or RDDLPlanningModel.OBJECT_SEP in var:
2062
2443
  raise ValueError(f'State dictionary passed to the JAX policy is '
2063
2444
  f'grounded, since it contains the key <{var}>, '
2064
2445
  f'but a vectorized environment is required: '
@@ -2066,9 +2447,8 @@ r"""
2066
2447
 
2067
2448
  # must be numeric array
2068
2449
  # exception is for POMDPs at 1st epoch when observ-fluents are None
2069
- dtype = np.atleast_1d(values).dtype
2070
- if not np.issubdtype(dtype, np.number) \
2071
- and not np.issubdtype(dtype, np.bool_):
2450
+ dtype = np.result_type(values)
2451
+ if not np.issubdtype(dtype, np.number) and not np.issubdtype(dtype, np.bool_):
2072
2452
  if step == 0 and var in self.rddl.observ_fluents:
2073
2453
  subs[var] = self.test_compiled.init_values[var]
2074
2454
  else:
@@ -2080,40 +2460,7 @@ r"""
2080
2460
  actions = self.test_policy(key, params, policy_hyperparams, step, subs)
2081
2461
  actions = jax.tree_map(np.asarray, actions)
2082
2462
  return actions
2083
-
2084
-
2085
- # ***********************************************************************
2086
- # ALL VERSIONS OF RISK FUNCTIONS
2087
- #
2088
- # Based on the original paper "A Distributional Framework for Risk-Sensitive
2089
- # End-to-End Planning in Continuous MDPs" by Patton et al., AAAI 2022.
2090
- #
2091
- # Original risk functions:
2092
- # - entropic utility
2093
- # - mean-variance approximation
2094
- # - conditional value at risk with straight-through gradient trick
2095
- #
2096
- # ***********************************************************************
2097
-
2098
-
2099
- @jax.jit
2100
- def entropic_utility(returns: jnp.ndarray, beta: float) -> float:
2101
- return (-1.0 / beta) * jax.scipy.special.logsumexp(
2102
- -beta * returns, b=1.0 / returns.size)
2103
-
2104
-
2105
- @jax.jit
2106
- def mean_variance_utility(returns: jnp.ndarray, beta: float) -> float:
2107
- return jnp.mean(returns) - 0.5 * beta * jnp.var(returns)
2108
-
2109
-
2110
- @jax.jit
2111
- def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
2112
- var = jnp.percentile(returns, q=100 * alpha)
2113
- mask = returns <= var
2114
- weights = mask / jnp.maximum(1, jnp.sum(mask))
2115
- return jnp.sum(returns * weights)
2116
-
2463
+
2117
2464
 
2118
2465
  # ***********************************************************************
2119
2466
  # ALL VERSIONS OF CONTROLLERS
@@ -2225,8 +2572,7 @@ class JaxOnlineController(BaseAgent):
2225
2572
  self.callback = callback
2226
2573
  params = callback['best_params']
2227
2574
  self.key, subkey = random.split(self.key)
2228
- actions = planner.get_action(
2229
- subkey, params, 0, state, self.eval_hyperparams)
2575
+ actions = planner.get_action(subkey, params, 0, state, self.eval_hyperparams)
2230
2576
  if self.warm_start:
2231
2577
  self.guess = planner.plan.guess_next_epoch(params)
2232
2578
  return actions