pyRDDLGym-jax 2.0__py3-none-any.whl → 2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -47,7 +47,9 @@ import jax.random as random
  import numpy as np
  import optax
  import termcolor
- from tqdm import tqdm
+ from tqdm import tqdm, TqdmWarning
+ import warnings
+ warnings.filterwarnings("ignore", category=TqdmWarning)

  from pyRDDLGym.core.compiler.model import RDDLPlanningModel, RDDLLiftedModel
  from pyRDDLGym.core.debug.logger import Logger
@@ -69,8 +71,7 @@ try:
  from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard
  except Exception:
  raise_warning('Failed to load the dashboard visualization tool: '
- 'please make sure you have installed the required packages.',
- 'red')
+ 'please make sure you have installed the required packages.', 'red')
  traceback.print_exc()
  JaxPlannerDashboard = None

@@ -133,7 +134,7 @@ def _load_config(config, args):
  comp_kwargs = model_args.get('complement_kwargs', {})
  compare_name = model_args.get('comparison', 'SigmoidComparison')
  compare_kwargs = model_args.get('comparison_kwargs', {})
- sampling_name = model_args.get('sampling', 'GumbelSoftmax')
+ sampling_name = model_args.get('sampling', 'SoftRandomSampling')
  sampling_kwargs = model_args.get('sampling_kwargs', {})
  rounding_name = model_args.get('rounding', 'SoftRounding')
  rounding_kwargs = model_args.get('rounding_kwargs', {})
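Note: the hunk above changes the default differentiable sampling relaxation from GumbelSoftmax to SoftRandomSampling, so configs that leave the `sampling` key unset silently pick up the new relaxation. A minimal sketch of the lookup, assuming a hand-built `model_args` dictionary (in the package this dictionary comes from the parsed planner config file, which is not shown in this hunk):

    # hypothetical stand-in for the parsed model section of a planner config
    model_args = {'comparison': 'SigmoidComparison'}  # no 'sampling' key supplied
    sampling_name = model_args.get('sampling', 'SoftRandomSampling')
    assert sampling_name == 'SoftRandomSampling'      # 2.2 default; 2.0 defaulted to 'GumbelSoftmax'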
@@ -156,8 +157,7 @@ def _load_config(config, args):
  initializer = _getattr_any(
  packages=[initializers, hk.initializers], item=plan_initializer)
  if initializer is None:
- raise_warning(
- f'Ignoring invalid initializer <{plan_initializer}>.', 'red')
+ raise_warning(f'Ignoring invalid initializer <{plan_initializer}>.', 'red')
  del plan_kwargs['initializer']
  else:
  init_kwargs = plan_kwargs.pop('initializer_kwargs', {})
@@ -174,8 +174,7 @@ def _load_config(config, args):
  activation = _getattr_any(
  packages=[jax.nn, jax.numpy], item=plan_activation)
  if activation is None:
- raise_warning(
- f'Ignoring invalid activation <{plan_activation}>.', 'red')
+ raise_warning(f'Ignoring invalid activation <{plan_activation}>.', 'red')
  del plan_kwargs['activation']
  else:
  plan_kwargs['activation'] = activation
@@ -189,8 +188,7 @@ def _load_config(config, args):
  if planner_optimizer is not None:
  optimizer = _getattr_any(packages=[optax], item=planner_optimizer)
  if optimizer is None:
- raise_warning(
- f'Ignoring invalid optimizer <{planner_optimizer}>.', 'red')
+ raise_warning(f'Ignoring invalid optimizer <{planner_optimizer}>.', 'red')
  del planner_args['optimizer']
  else:
  planner_args['optimizer'] = optimizer
@@ -285,48 +283,14 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
  pvars_cast = set()
  for (var, values) in self.init_values.items():
  self.init_values[var] = np.asarray(values, dtype=self.REAL)
- if not np.issubdtype(np.atleast_1d(values).dtype, np.floating):
+ if not np.issubdtype(np.result_type(values), np.floating):
  pvars_cast.add(var)
  if pvars_cast:
  raise_warning(f'JAX gradient compiler requires that initial values '
  f'of p-variables {pvars_cast} be cast to float.')

  # overwrite basic operations with fuzzy ones
- self.RELATIONAL_OPS = {
- '>=': logic.greater_equal,
- '<=': logic.less_equal,
- '<': logic.less,
- '>': logic.greater,
- '==': logic.equal,
- '~=': logic.not_equal
- }
- self.LOGICAL_NOT = logic.logical_not
- self.LOGICAL_OPS = {
- '^': logic.logical_and,
- '&': logic.logical_and,
- '|': logic.logical_or,
- '~': logic.xor,
- '=>': logic.implies,
- '<=>': logic.equiv
- }
- self.AGGREGATION_OPS['forall'] = logic.forall
- self.AGGREGATION_OPS['exists'] = logic.exists
- self.AGGREGATION_OPS['argmin'] = logic.argmin
- self.AGGREGATION_OPS['argmax'] = logic.argmax
- self.KNOWN_UNARY['sgn'] = logic.sgn
- self.KNOWN_UNARY['floor'] = logic.floor
- self.KNOWN_UNARY['ceil'] = logic.ceil
- self.KNOWN_UNARY['round'] = logic.round
- self.KNOWN_UNARY['sqrt'] = logic.sqrt
- self.KNOWN_BINARY['div'] = logic.div
- self.KNOWN_BINARY['mod'] = logic.mod
- self.KNOWN_BINARY['fmod'] = logic.mod
- self.IF_HELPER = logic.control_if
- self.SWITCH_HELPER = logic.control_switch
- self.BERNOULLI_HELPER = logic.bernoulli
- self.DISCRETE_HELPER = logic.discrete
- self.POISSON_HELPER = logic.poisson
- self.GEOMETRIC_HELPER = logic.geometric
+ self.OPS = logic.get_operator_dicts()

  def _jax_stop_grad(self, jax_expr):
  def _jax_wrapped_stop_grad(x, params, key):
@@ -575,7 +539,7 @@ class JaxStraightLinePlan(JaxPlan):
  def _jax_non_bool_param_to_action(var, param, hyperparams):
  if wrap_non_bool:
  lower, upper = bounds_safe[var]
- mb, ml, mu, mn = [mask.astype(compiled.REAL)
+ mb, ml, mu, mn = [jnp.asarray(mask, dtype=compiled.REAL)
  for mask in cond_lists[var]]
  action = (
  mb * (lower + (upper - lower) * jax.nn.sigmoid(param)) +
@@ -660,7 +624,7 @@ class JaxStraightLinePlan(JaxPlan):
  action = _jax_non_bool_param_to_action(var, action, hyperparams)
  action = jnp.clip(action, *bounds[var])
  if ranges[var] == 'int':
- action = jnp.round(action).astype(compiled.INT)
+ action = jnp.asarray(jnp.round(action), dtype=compiled.INT)
  actions[var] = action
  return actions

@@ -961,12 +925,11 @@ class JaxDeepReactivePolicy(JaxPlan):
  non_bool_dims = 0
  for (var, values) in observed_vars.items():
  if ranges[var] != 'bool':
- value_size = np.atleast_1d(values).size
+ value_size = np.size(values)
  if normalize_per_layer and value_size == 1:
  raise_warning(
  f'Cannot apply layer norm to state-fluent <{var}> '
- f'of size 1: setting normalize_per_layer = False.',
- 'red')
+ f'of size 1: setting normalize_per_layer = False.', 'red')
  normalize_per_layer = False
  non_bool_dims += value_size
  if not normalize_per_layer and non_bool_dims == 1:
@@ -990,9 +953,11 @@ class JaxDeepReactivePolicy(JaxPlan):
  else:
  if normalize and normalize_per_layer:
  normalizer = hk.LayerNorm(
- axis=-1, param_axis=-1,
+ axis=-1,
+ param_axis=-1,
  name=f'input_norm_{input_names[var]}',
- **self._normalizer_kwargs)
+ **self._normalizer_kwargs
+ )
  state = normalizer(state)
  states_non_bool.append(state)
  non_bool_dims += state.size
@@ -1001,8 +966,11 @@ class JaxDeepReactivePolicy(JaxPlan):
  # optionally perform layer normalization on the non-bool inputs
  if normalize and not normalize_per_layer and non_bool_dims:
  normalizer = hk.LayerNorm(
- axis=-1, param_axis=-1, name='input_norm',
- **self._normalizer_kwargs)
+ axis=-1,
+ param_axis=-1,
+ name='input_norm',
+ **self._normalizer_kwargs
+ )
  normalized = normalizer(state[:non_bool_dims])
  state = state.at[:non_bool_dims].set(normalized)
  return state
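Note: the two hunks above only re-wrap the hk.LayerNorm calls across several lines; behavior is unchanged. For readers unfamiliar with Haiku, here is a standalone sketch of the same normalizer configuration. It assumes `**self._normalizer_kwargs` expands to `create_scale`/`create_offset` (Haiku requires them, but the values actually used by the policy are not shown in this hunk), and the wrapper function and input shapes are illustrative only.

    import haiku as hk
    import jax
    import jax.numpy as jnp

    def normalize_inputs(x):
        # same keyword layout as the reformatted call in the diff above
        norm = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True,
                            param_axis=-1, name='input_norm')
        return norm(x)

    fwd = hk.without_apply_rng(hk.transform(normalize_inputs))
    params = fwd.init(jax.random.PRNGKey(0), jnp.ones((4, 8)))
    out = fwd.apply(params, jnp.ones((4, 8)))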
@@ -1021,7 +989,8 @@ class JaxDeepReactivePolicy(JaxPlan):
  actions = {}
  for (var, size) in layer_sizes.items():
  linear = hk.Linear(size, name=layer_names[var], w_init=init)
- reshape = hk.Reshape(output_shape=shapes[var], preserve_dims=-1,
+ reshape = hk.Reshape(output_shape=shapes[var],
+ preserve_dims=-1,
  name=f'reshape_{layer_names[var]}')
  output = reshape(linear(hidden))
  if not shapes[var]:
@@ -1034,7 +1003,7 @@ class JaxDeepReactivePolicy(JaxPlan):
  else:
  if wrap_non_bool:
  lower, upper = bounds_safe[var]
- mb, ml, mu, mn = [mask.astype(compiled.REAL)
+ mb, ml, mu, mn = [jnp.asarray(mask, dtype=compiled.REAL)
  for mask in cond_lists[var]]
  action = (
  mb * (lower + (upper - lower) * jax.nn.sigmoid(output)) +
@@ -1048,8 +1017,7 @@ class JaxDeepReactivePolicy(JaxPlan):

  # for constraint satisfaction wrap bool actions with softmax
  if use_constraint_satisfaction:
- linear = hk.Linear(
- bool_action_count, name='output_bool', w_init=init)
+ linear = hk.Linear(bool_action_count, name='output_bool', w_init=init)
  output = jax.nn.softmax(linear(hidden))
  actions[bool_key] = output

@@ -1087,8 +1055,7 @@ class JaxDeepReactivePolicy(JaxPlan):

  # test action prediction
  def _jax_wrapped_drp_predict_test(key, params, hyperparams, step, subs):
- actions = _jax_wrapped_drp_predict_train(
- key, params, hyperparams, step, subs)
+ actions = _jax_wrapped_drp_predict_train(key, params, hyperparams, step, subs)
  new_actions = {}
  for (var, action) in actions.items():
  prange = ranges[var]
@@ -1096,7 +1063,7 @@ class JaxDeepReactivePolicy(JaxPlan):
  new_action = action > 0.5
  elif prange == 'int':
  action = jnp.clip(action, *bounds[var])
- new_action = jnp.round(action).astype(compiled.INT)
+ new_action = jnp.asarray(jnp.round(action), dtype=compiled.INT)
  else:
  new_action = jnp.clip(action, *bounds[var])
  new_actions[var] = new_action
@@ -1247,17 +1214,22 @@ class GaussianPGPE(PGPE):
  init_sigma: float=1.0,
  sigma_range: Tuple[float, float]=(1e-5, 1e5),
  scale_reward: bool=True,
+ min_reward_scale: float=1e-5,
  super_symmetric: bool=True,
  super_symmetric_accurate: bool=True,
  optimizer: Callable[..., optax.GradientTransformation]=optax.adam,
  optimizer_kwargs_mu: Optional[Kwargs]=None,
- optimizer_kwargs_sigma: Optional[Kwargs]=None) -> None:
+ optimizer_kwargs_sigma: Optional[Kwargs]=None,
+ start_entropy_coeff: float=1e-3,
+ end_entropy_coeff: float=1e-8,
+ max_kl_update: Optional[float]=None) -> None:
  '''Creates a new Gaussian PGPE planner.

  :param batch_size: how many policy parameters to sample per optimization step
  :param init_sigma: initial standard deviation of Gaussian
  :param sigma_range: bounds to constrain standard deviation
  :param scale_reward: whether to apply reward scaling as in the paper
+ :param min_reward_scale: minimum reward scaling to avoid underflow
  :param super_symmetric: whether to use super-symmetric sampling as in the paper
  :param super_symmetric_accurate: whether to use the accurate formula for super-
  symmetric sampling or the simplified but biased formula
@@ -1266,6 +1238,9 @@ class GaussianPGPE(PGPE):
  factory for the mean optimizer
  :param optimizer_kwargs_sigma: a dictionary of parameters to pass to the SGD
  factory for the standard deviation optimizer
+ :param start_entropy_coeff: starting entropy regularization coefficient for Gaussian
+ :param end_entropy_coeff: ending entropy regularization coefficient for Gaussian
+ :param max_kl_update: bound on kl-divergence between parameter updates
  '''
  super().__init__()

@@ -1273,8 +1248,13 @@ class GaussianPGPE(PGPE):
  self.init_sigma = init_sigma
  self.sigma_range = sigma_range
  self.scale_reward = scale_reward
+ self.min_reward_scale = min_reward_scale
  self.super_symmetric = super_symmetric
  self.super_symmetric_accurate = super_symmetric_accurate
+
+ # entropy regularization penalty is decayed exponentially between these values
+ self.start_entropy_coeff = start_entropy_coeff
+ self.end_entropy_coeff = end_entropy_coeff

  # set optimizers
  if optimizer_kwargs_mu is None:
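Note: together, the three hunks above add four knobs to GaussianPGPE: `min_reward_scale` (a floor on the reward-scaling denominator, replacing the hard-coded MIN_NORM constant from 2.0), an entropy bonus that decays from `start_entropy_coeff` to `end_entropy_coeff` over the training budget, and an optional `max_kl_update` trust-region bound. A hedged construction sketch with purely illustrative values (only the keyword names are taken from the signature and docstring above; `batch_size` comes from the docstring rather than this hunk):

    pgpe = GaussianPGPE(
        batch_size=1,               # policy-parameter samples per optimization step
        init_sigma=1.0,
        min_reward_scale=1e-5,      # new in 2.2: floor applied when scaling rewards
        start_entropy_coeff=1e-3,   # new in 2.2: entropy bonus decays from this value ...
        end_entropy_coeff=1e-8,     # ... down to this value as the budget is spent
        max_kl_update=0.1)          # new in 2.2: optional KL bound between updates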
@@ -1284,36 +1264,62 @@ class GaussianPGPE(PGPE):
  optimizer_kwargs_sigma = {'learning_rate': 0.1}
  self.optimizer_kwargs_sigma = optimizer_kwargs_sigma
  self.optimizer_name = optimizer
- mu_optimizer = optimizer(**optimizer_kwargs_mu)
- sigma_optimizer = optimizer(**optimizer_kwargs_sigma)
+ try:
+ mu_optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs_mu)
+ sigma_optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs_sigma)
+ except Exception as _:
+ raise_warning(
+ f'Failed to inject hyperparameters into optax optimizer for PGPE, '
+ 'rolling back to safer method: please note that kl-divergence '
+ 'constraints will be disabled.', 'red')
+ mu_optimizer = optimizer(**optimizer_kwargs_mu)
+ sigma_optimizer = optimizer(**optimizer_kwargs_sigma)
+ max_kl_update = None
  self.optimizers = (mu_optimizer, sigma_optimizer)
+ self.max_kl = max_kl_update

  def __str__(self) -> str:
  return (f'PGPE hyper-parameters:\n'
- f' method ={self.__class__.__name__}\n'
- f' batch_size ={self.batch_size}\n'
- f' init_sigma ={self.init_sigma}\n'
- f' sigma_range ={self.sigma_range}\n'
- f' scale_reward ={self.scale_reward}\n'
- f' super_symmetric={self.super_symmetric}\n'
- f' accurate ={self.super_symmetric_accurate}\n'
- f' optimizer ={self.optimizer_name}\n'
+ f' method ={self.__class__.__name__}\n'
+ f' batch_size ={self.batch_size}\n'
+ f' init_sigma ={self.init_sigma}\n'
+ f' sigma_range ={self.sigma_range}\n'
+ f' scale_reward ={self.scale_reward}\n'
+ f' min_reward_scale ={self.min_reward_scale}\n'
+ f' super_symmetric ={self.super_symmetric}\n'
+ f' accurate ={self.super_symmetric_accurate}\n'
+ f' optimizer ={self.optimizer_name}\n'
  f' optimizer_kwargs:\n'
  f' mu ={self.optimizer_kwargs_mu}\n'
  f' sigma={self.optimizer_kwargs_sigma}\n'
+ f' start_entropy_coeff={self.start_entropy_coeff}\n'
+ f' end_entropy_coeff ={self.end_entropy_coeff}\n'
+ f' max_kl_update ={self.max_kl}\n'
  )

  def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type) -> None:
- MIN_NORM = 1e-5
  sigma0 = self.init_sigma
  sigma_range = self.sigma_range
  scale_reward = self.scale_reward
+ min_reward_scale = self.min_reward_scale
  super_symmetric = self.super_symmetric
  super_symmetric_accurate = self.super_symmetric_accurate
  batch_size = self.batch_size
  optimizers = (mu_optimizer, sigma_optimizer) = self.optimizers
-
- # initializer
+ max_kl = self.max_kl
+
+ # entropy regularization penalty is decayed exponentially by elapsed budget
+ start_entropy_coeff = self.start_entropy_coeff
+ if start_entropy_coeff == 0:
+ entropy_coeff_decay = 0
+ else:
+ entropy_coeff_decay = (self.end_entropy_coeff / start_entropy_coeff) ** 0.01
+
+ # ***********************************************************************
+ # INITIALIZATION OF POLICY
+ #
+ # ***********************************************************************
+
  def _jax_wrapped_pgpe_init(key, policy_params):
  mu = policy_params
  sigma = jax.tree_map(lambda x: sigma0 * jnp.ones_like(x), mu)
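Note: restating the decay schedule set up above, with progress p measured in percent of the training budget (the training loop later computes it as 100 * min(1, max(0, elapsed / train_seconds, it / (epochs - 1)))), the entropy coefficient used in the sigma gradient is

    \mathrm{ent}(p) \;=\; c_{\mathrm{start}}\left(\left(\frac{c_{\mathrm{end}}}{c_{\mathrm{start}}}\right)^{0.01}\right)^{p} \;=\; c_{\mathrm{start}}\left(\frac{c_{\mathrm{end}}}{c_{\mathrm{start}}}\right)^{p/100}

with c_start = start_entropy_coeff and c_end = end_entropy_coeff, so the coefficient equals start_entropy_coeff at the start of training, end_entropy_coeff when the budget is exhausted, and is identically zero when start_entropy_coeff == 0.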
@@ -1324,7 +1330,11 @@ class GaussianPGPE(PGPE):

  self._initializer = jax.jit(_jax_wrapped_pgpe_init)

- # parameter sampling functions
+ # ***********************************************************************
+ # PARAMETER SAMPLING FUNCTIONS
+ #
+ # ***********************************************************************
+
  def _jax_wrapped_mu_noise(key, sigma):
  return sigma * random.normal(key, shape=jnp.shape(sigma), dtype=real_dtype)

@@ -1334,19 +1344,20 @@ class GaussianPGPE(PGPE):
  a = (sigma - jnp.abs(epsilon)) / sigma
  if super_symmetric_accurate:
  aa = jnp.abs(a)
+ aa3 = jnp.power(aa, 3)
  epsilon_star = jnp.sign(epsilon) * phi * jnp.where(
  a <= 0,
- jnp.exp(c1 * aa * (aa * aa - 1) / jnp.log(aa + 1e-10) + c2 * aa),
- jnp.exp(aa - c3 * aa * jnp.log(1.0 - jnp.power(aa, 3) + 1e-10))
+ jnp.exp(c1 * (aa3 - aa) / jnp.log(aa + 1e-10) + c2 * aa),
+ jnp.exp(aa - c3 * aa * jnp.log(1.0 - aa3 + 1e-10))
  )
  else:
  epsilon_star = jnp.sign(epsilon) * phi * jnp.exp(a)
  return epsilon_star

  def _jax_wrapped_sample_params(key, mu, sigma):
- keys = random.split(key, num=len(jax.tree_util.tree_leaves(mu)))
- keys_pytree = jax.tree_util.tree_unflatten(
- treedef=jax.tree_util.tree_structure(mu), leaves=keys)
+ treedef = jax.tree_util.tree_structure(sigma)
+ keys = random.split(key, num=treedef.num_leaves)
+ keys_pytree = jax.tree_util.tree_unflatten(treedef=treedef, leaves=keys)
  epsilon = jax.tree_map(_jax_wrapped_mu_noise, keys_pytree, sigma)
  p1 = jax.tree_map(jnp.add, mu, epsilon)
  p2 = jax.tree_map(jnp.subtract, mu, epsilon)
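Note: the rewritten `_jax_wrapped_sample_params` above uses a common JAX idiom: split one PRNG key into one key per pytree leaf, rebuild a key pytree with the same structure, and map over it together with the parameters. A self-contained sketch of just that idiom on toy parameters (not planner code):

    import jax
    import jax.numpy as jnp

    params = {'w': jnp.zeros((2, 3)), 'b': jnp.zeros(3)}

    # one fresh key per leaf, arranged into the same pytree structure as params
    treedef = jax.tree_util.tree_structure(params)
    keys = jax.random.split(jax.random.PRNGKey(0), num=treedef.num_leaves)
    keys_pytree = jax.tree_util.tree_unflatten(treedef, keys)

    # draw leaf-shaped Gaussian noise, analogous to the epsilon perturbations above
    noise = jax.tree_util.tree_map(
        lambda k, p: jax.random.normal(k, shape=p.shape), keys_pytree, params)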
@@ -1356,14 +1367,18 @@ class GaussianPGPE(PGPE):
  p4 = jax.tree_map(jnp.subtract, mu, epsilon_star)
  else:
  epsilon_star, p3, p4 = epsilon, p1, p2
- return (p1, p2, p3, p4), (epsilon, epsilon_star)
+ return p1, p2, p3, p4, epsilon, epsilon_star

- # policy gradient update functions
+ # ***********************************************************************
+ # POLICY GRADIENT CALCULATION
+ #
+ # ***********************************************************************
+
  def _jax_wrapped_mu_grad(epsilon, epsilon_star, r1, r2, r3, r4, m):
  if super_symmetric:
  if scale_reward:
- scale1 = jnp.maximum(MIN_NORM, m - (r1 + r2) / 2)
- scale2 = jnp.maximum(MIN_NORM, m - (r3 + r4) / 2)
+ scale1 = jnp.maximum(min_reward_scale, m - (r1 + r2) / 2)
+ scale2 = jnp.maximum(min_reward_scale, m - (r3 + r4) / 2)
  else:
  scale1 = scale2 = 1.0
  r_mu1 = (r1 - r2) / (2 * scale1)
@@ -1371,37 +1386,37 @@ class GaussianPGPE(PGPE):
  grad = -(r_mu1 * epsilon + r_mu2 * epsilon_star)
  else:
  if scale_reward:
- scale = jnp.maximum(MIN_NORM, m - (r1 + r2) / 2)
+ scale = jnp.maximum(min_reward_scale, m - (r1 + r2) / 2)
  else:
  scale = 1.0
  r_mu = (r1 - r2) / (2 * scale)
  grad = -r_mu * epsilon
  return grad

- def _jax_wrapped_sigma_grad(epsilon, epsilon_star, sigma, r1, r2, r3, r4, m):
+ def _jax_wrapped_sigma_grad(epsilon, epsilon_star, sigma, r1, r2, r3, r4, m, ent):
  if super_symmetric:
  mask = r1 + r2 >= r3 + r4
  epsilon_tau = mask * epsilon + (1 - mask) * epsilon_star
- s = epsilon_tau * epsilon_tau / sigma - sigma
+ s = jnp.square(epsilon_tau) / sigma - sigma
  if scale_reward:
- scale = jnp.maximum(MIN_NORM, m - (r1 + r2 + r3 + r4) / 4)
+ scale = jnp.maximum(min_reward_scale, m - (r1 + r2 + r3 + r4) / 4)
  else:
  scale = 1.0
  r_sigma = ((r1 + r2) - (r3 + r4)) / (4 * scale)
  else:
- s = epsilon * epsilon / sigma - sigma
+ s = jnp.square(epsilon) / sigma - sigma
  if scale_reward:
- scale = jnp.maximum(MIN_NORM, jnp.abs(m))
+ scale = jnp.maximum(min_reward_scale, jnp.abs(m))
  else:
  scale = 1.0
  r_sigma = (r1 + r2) / (2 * scale)
- grad = -r_sigma * s
+ grad = -(r_sigma * s + ent / sigma)
  return grad

- def _jax_wrapped_pgpe_grad(key, mu, sigma, r_max,
+ def _jax_wrapped_pgpe_grad(key, mu, sigma, r_max, ent,
  policy_hyperparams, subs, model_params):
  key, subkey = random.split(key)
- (p1, p2, p3, p4), (epsilon, epsilon_star) = _jax_wrapped_sample_params(
+ p1, p2, p3, p4, epsilon, epsilon_star = _jax_wrapped_sample_params(
  key, mu, sigma)
  r1 = -loss_fn(subkey, p1, policy_hyperparams, subs, model_params)[0]
  r2 = -loss_fn(subkey, p2, policy_hyperparams, subs, model_params)[0]
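Note: with the new `ent` argument, the non-super-symmetric branch of the sigma gradient above becomes, per parameter,

    \nabla_\sigma \;=\; -\left( r_\sigma \left( \frac{\epsilon^2}{\sigma} - \sigma \right) \;+\; \frac{\mathrm{ent}}{\sigma} \right)

The extra ent / sigma term acts like the gradient of an entropy bonus on the Gaussian search distribution (its entropy grows with log sigma), so descending this loss nudges sigma upward and discourages premature collapse of exploration. Apart from that term, these hunks rename MIN_NORM to the configurable min_reward_scale, rewrite squares with jnp.square, and thread ent through the gradient function signatures.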
@@ -1419,42 +1434,76 @@ class GaussianPGPE(PGPE):
  epsilon, epsilon_star
  )
  grad_sigma = jax.tree_map(
- partial(_jax_wrapped_sigma_grad, r1=r1, r2=r2, r3=r3, r4=r4, m=r_max),
+ partial(_jax_wrapped_sigma_grad,
+ r1=r1, r2=r2, r3=r3, r4=r4, m=r_max, ent=ent),
  epsilon, epsilon_star, sigma
  )
  return grad_mu, grad_sigma, r_max

- def _jax_wrapped_pgpe_grad_batched(key, pgpe_params, r_max,
+ def _jax_wrapped_pgpe_grad_batched(key, pgpe_params, r_max, ent,
  policy_hyperparams, subs, model_params):
  mu, sigma = pgpe_params
  if batch_size == 1:
  mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad(
- key, mu, sigma, r_max, policy_hyperparams, subs, model_params)
+ key, mu, sigma, r_max, ent, policy_hyperparams, subs, model_params)
  else:
  keys = random.split(key, num=batch_size)
  mu_grads, sigma_grads, r_maxs = jax.vmap(
  _jax_wrapped_pgpe_grad,
- in_axes=(0, None, None, None, None, None, None)
- )(keys, mu, sigma, r_max, policy_hyperparams, subs, model_params)
- mu_grad = jax.tree_map(partial(jnp.mean, axis=0), mu_grads)
- sigma_grad = jax.tree_map(partial(jnp.mean, axis=0), sigma_grads)
+ in_axes=(0, None, None, None, None, None, None, None)
+ )(keys, mu, sigma, r_max, ent, policy_hyperparams, subs, model_params)
+ mu_grad, sigma_grad = jax.tree_map(
+ partial(jnp.mean, axis=0), (mu_grads, sigma_grads))
  new_r_max = jnp.max(r_maxs)
  return mu_grad, sigma_grad, new_r_max
+
+ # ***********************************************************************
+ # PARAMETER UPDATE
+ #
+ # ***********************************************************************

- def _jax_wrapped_pgpe_update(key, pgpe_params, r_max,
+ def _jax_wrapped_pgpe_kl_term(mu, sigma, old_mu, old_sigma):
+ return 0.5 * jnp.sum(2 * jnp.log(sigma / old_sigma) +
+ jnp.square(old_sigma / sigma) +
+ jnp.square((mu - old_mu) / sigma) - 1)
+
+ def _jax_wrapped_pgpe_update(key, pgpe_params, r_max, progress,
  policy_hyperparams, subs, model_params,
  pgpe_opt_state):
+ # regular update
  mu, sigma = pgpe_params
  mu_state, sigma_state = pgpe_opt_state
+ ent = start_entropy_coeff * jnp.power(entropy_coeff_decay, progress)
  mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad_batched(
- key, pgpe_params, r_max, policy_hyperparams, subs, model_params)
+ key, pgpe_params, r_max, ent, policy_hyperparams, subs, model_params)
  mu_updates, new_mu_state = mu_optimizer.update(mu_grad, mu_state, params=mu)
  sigma_updates, new_sigma_state = sigma_optimizer.update(
  sigma_grad, sigma_state, params=sigma)
  new_mu = optax.apply_updates(mu, mu_updates)
- new_mu, converged = projection(new_mu, policy_hyperparams)
  new_sigma = optax.apply_updates(sigma, sigma_updates)
  new_sigma = jax.tree_map(lambda x: jnp.clip(x, *sigma_range), new_sigma)
+
+ # respect KL divergence constraint with old parameters
+ if max_kl is not None:
+ old_mu_lr = new_mu_state.hyperparams['learning_rate']
+ old_sigma_lr = new_sigma_state.hyperparams['learning_rate']
+ kl_terms = jax.tree_map(
+ _jax_wrapped_pgpe_kl_term, new_mu, new_sigma, mu, sigma)
+ total_kl = jax.tree_util.tree_reduce(jnp.add, kl_terms)
+ kl_reduction = jnp.minimum(1.0, jnp.sqrt(max_kl / total_kl))
+ mu_state.hyperparams['learning_rate'] = old_mu_lr * kl_reduction
+ sigma_state.hyperparams['learning_rate'] = old_sigma_lr * kl_reduction
+ mu_updates, new_mu_state = mu_optimizer.update(mu_grad, mu_state, params=mu)
+ sigma_updates, new_sigma_state = sigma_optimizer.update(
+ sigma_grad, sigma_state, params=sigma)
+ new_mu = optax.apply_updates(mu, mu_updates)
+ new_sigma = optax.apply_updates(sigma, sigma_updates)
+ new_sigma = jax.tree_map(lambda x: jnp.clip(x, *sigma_range), new_sigma)
+ new_mu_state.hyperparams['learning_rate'] = old_mu_lr
+ new_sigma_state.hyperparams['learning_rate'] = old_sigma_lr
+
+ # apply projection step and finalize results
+ new_mu, converged = projection(new_mu, policy_hyperparams)
  new_pgpe_params = (new_mu, new_sigma)
  new_pgpe_opt_state = (new_mu_state, new_sigma_state)
  policy_params = new_mu
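Note: the KL helper above is the closed form for diagonal Gaussians, summed over all parameter entries,

    \mathrm{KL}\!\left(\mathcal{N}(\mu_{\mathrm{old}},\sigma_{\mathrm{old}}^2)\,\middle\|\,\mathcal{N}(\mu_{\mathrm{new}},\sigma_{\mathrm{new}}^2)\right) \;=\; \frac{1}{2}\sum_i \left[\, 2\ln\frac{\sigma_{\mathrm{new},i}}{\sigma_{\mathrm{old},i}} \;+\; \frac{\sigma_{\mathrm{old},i}^2}{\sigma_{\mathrm{new},i}^2} \;+\; \frac{(\mu_{\mathrm{new},i}-\mu_{\mathrm{old},i})^2}{\sigma_{\mathrm{new},i}^2} \;-\; 1 \right]

When max_kl_update is set, the update is first taken at the configured learning rates, the divergence of the tentative parameters from the old ones is measured, both learning rates are scaled by min(1, sqrt(max_kl / KL)) (a no-op when the bound is already satisfied), the update is recomputed, and the original rates are then restored. This is why the constructor now wraps both optimizers in optax.inject_hyperparams: the scaling edits learning_rate inside the optimizer state, and the try/except fallback without injection disables the KL constraint.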
@@ -1463,6 +1512,71 @@ class GaussianPGPE(PGPE):
  self._update = jax.jit(_jax_wrapped_pgpe_update)


+ # ***********************************************************************
+ # ALL VERSIONS OF RISK FUNCTIONS
+ #
+ # Based on the original paper "A Distributional Framework for Risk-Sensitive
+ # End-to-End Planning in Continuous MDPs" by Patton et al., AAAI 2022.
+ #
+ # Original risk functions:
+ # - entropic utility
+ # - mean-variance
+ # - mean-semideviation
+ # - conditional value at risk with straight-through gradient trick
+ #
+ # ***********************************************************************
+
+
+ @jax.jit
+ def entropic_utility(returns: jnp.ndarray, beta: float) -> float:
+ return (-1.0 / beta) * jax.scipy.special.logsumexp(
+ -beta * returns, b=1.0 / returns.size)
+
+
+ @jax.jit
+ def mean_variance_utility(returns: jnp.ndarray, beta: float) -> float:
+ return jnp.mean(returns) - 0.5 * beta * jnp.var(returns)
+
+
+ @jax.jit
+ def mean_deviation_utility(returns: jnp.ndarray, beta: float) -> float:
+ return jnp.mean(returns) - 0.5 * beta * jnp.std(returns)
+
+
+ @jax.jit
+ def mean_semideviation_utility(returns: jnp.ndarray, beta: float) -> float:
+ mu = jnp.mean(returns)
+ msd = jnp.sqrt(jnp.mean(jnp.square(jnp.minimum(0.0, returns - mu))))
+ return mu - 0.5 * beta * msd
+
+
+ @jax.jit
+ def mean_semivariance_utility(returns: jnp.ndarray, beta: float) -> float:
+ mu = jnp.mean(returns)
+ msv = jnp.mean(jnp.square(jnp.minimum(0.0, returns - mu)))
+ return mu - 0.5 * beta * msv
+
+
+ @jax.jit
+ def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
+ var = jnp.percentile(returns, q=100 * alpha)
+ mask = returns <= var
+ weights = mask / jnp.maximum(1, jnp.sum(mask))
+ return jnp.sum(returns * weights)
+
+
+ UTILITY_LOOKUP = {
+ 'mean': jnp.mean,
+ 'mean_var': mean_variance_utility,
+ 'mean_std': mean_deviation_utility,
+ 'mean_semivar': mean_semivariance_utility,
+ 'mean_semidev': mean_semideviation_utility,
+ 'entropic': entropic_utility,
+ 'exponential': entropic_utility,
+ 'cvar': cvar_utility
+ }
+
+
  # ***********************************************************************
  # ALL VERSIONS OF JAX PLANNER
  #
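Note on the risk functions added in the hunk above: for reference, the registered utilities compute the following quantities for a batch of returns R_1, ..., R_N with risk parameter beta (or level alpha for CVaR):

    \text{entropic / exponential:}\quad U = -\tfrac{1}{\beta}\ln\!\Big(\tfrac{1}{N}\textstyle\sum_i e^{-\beta R_i}\Big)
    \text{mean\_var:}\quad U = \bar{R} - \tfrac{\beta}{2}\,\mathrm{Var}(R)
    \text{mean\_std:}\quad U = \bar{R} - \tfrac{\beta}{2}\,\mathrm{Std}(R)
    \text{mean\_semivar:}\quad U = \bar{R} - \tfrac{\beta}{2}\cdot\tfrac{1}{N}\textstyle\sum_i \min(0,\,R_i-\bar{R})^2
    \text{mean\_semidev:}\quad U = \bar{R} - \tfrac{\beta}{2}\sqrt{\tfrac{1}{N}\textstyle\sum_i \min(0,\,R_i-\bar{R})^2}
    \text{cvar:}\quad U = \frac{\sum_i \mathbb{1}[R_i \le \mathrm{VaR}_\alpha]\,R_i}{\max\!\big(1,\ \sum_i \mathbb{1}[R_i \le \mathrm{VaR}_\alpha]\big)},\quad \mathrm{VaR}_\alpha = \text{empirical } \alpha\text{-quantile of } R

Relative to 2.0, mean_std, mean_semivar and mean_semidev are new, 'exponential' is an alias for the entropic utility, and the UTILITY_LOOKUP table replaces the if/elif chain that previously lived in JaxBackpropPlanner (see the next hunks).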
@@ -1525,8 +1639,7 @@ class JaxBackpropPlanner:
  reward as a form of normalization
  :param utility: how to aggregate return observations to compute utility
  of a policy or plan; must be either a function mapping jax array to a
- scalar, or a a string identifying the utility function by name
- ("mean", "mean_var", "entropic", or "cvar" are currently supported)
+ scalar, or a string identifying the utility function by name
  :param utility_kwargs: additional keyword arguments to pass hyper-
  parameters to the utility function call
  :param cpfs_without_grad: which CPFs do not have gradients (use straight
@@ -1584,18 +1697,11 @@ class JaxBackpropPlanner:
  # set utility
  if isinstance(utility, str):
  utility = utility.lower()
- if utility == 'mean':
- utility_fn = jnp.mean
- elif utility == 'mean_var':
- utility_fn = mean_variance_utility
- elif utility == 'entropic':
- utility_fn = entropic_utility
- elif utility == 'cvar':
- utility_fn = cvar_utility
- else:
+ utility_fn = UTILITY_LOOKUP.get(utility, None)
+ if utility_fn is None:
  raise RDDLNotImplementedError(
- f'Utility function <{utility}> is not supported: '
- 'must be one of ["mean", "mean_var", "entropic", "cvar"].')
+ f'Utility <{utility}> is not supported, '
+ f'must be one of {list(UTILITY_LOOKUP.keys())}.')
  else:
  utility_fn = utility
  self.utility = utility_fn
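Note: a small hedged sketch exercising the new name-based dispatch directly (it assumes UTILITY_LOOKUP and the utility functions defined earlier in this diff are importable from the planner module; the sample returns and parameter values are arbitrary):

    import jax.numpy as jnp

    returns = jnp.array([-5.0, 1.0, 2.0, 10.0])
    for name, kwargs in [('mean', {}), ('entropic', {'beta': 2.0}), ('cvar', {'alpha': 0.5})]:
        utility_fn = UTILITY_LOOKUP[name]
        print(name, float(utility_fn(returns, **kwargs)))

Passing any of these names as the planner's `utility` argument (with matching `utility_kwargs`) goes through the same table; an unknown name now raises RDDLNotImplementedError listing every valid key instead of the fixed four-name message from 2.0.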
@@ -1746,7 +1852,6 @@ r"""

  # optimization
  self.update = self._jax_update(train_loss)
- self.check_zero_grad = self._jax_check_zero_gradients()

  # pgpe option
  if self.use_pgpe:
@@ -1809,6 +1914,12 @@ r"""
  projection = self.plan.projection
  use_ls = self.line_search_kwargs is not None

+ # check if the gradients are all zeros
+ def _jax_wrapped_zero_gradients(grad):
+ leaves, _ = jax.tree_util.tree_flatten(
+ jax.tree_map(lambda g: jnp.allclose(g, 0), grad))
+ return jnp.all(jnp.asarray(leaves))
+
  # calculate the plan gradient w.r.t. return loss and update optimizer
  # also perform a projection step to satisfy constraints on actions
  def _jax_wrapped_loss_swapped(policy_params, key, policy_hyperparams,
@@ -1833,23 +1944,12 @@ r"""
  policy_params, converged = projection(policy_params, policy_hyperparams)
  log['grad'] = grad
  log['updates'] = updates
+ zero_grads = _jax_wrapped_zero_gradients(grad)
  return policy_params, converged, opt_state, opt_aux, \
- loss_val, log, model_params
+ loss_val, log, model_params, zero_grads

  return jax.jit(_jax_wrapped_plan_update)

- def _jax_check_zero_gradients(self):
-
- def _jax_wrapped_zero_gradient(grad):
- return jnp.allclose(grad, 0)
-
- def _jax_wrapped_zero_gradients(grad):
- leaves, _ = jax.tree_util.tree_flatten(
- jax.tree_map(_jax_wrapped_zero_gradient, grad))
- return jnp.all(jnp.asarray(leaves))
-
- return jax.jit(_jax_wrapped_zero_gradients)
-
  def _batched_init_subs(self, subs):
  rddl = self.rddl
  n_train, n_test = self.batch_size_train, self.batch_size_test
@@ -1865,7 +1965,7 @@ r"""
  f'{set(self.test_compiled.init_values.keys())}.')
  value = np.reshape(value, newshape=np.shape(init_value))[np.newaxis, ...]
  train_value = np.repeat(value, repeats=n_train, axis=0)
- train_value = train_value.astype(self.compiled.REAL)
+ train_value = np.asarray(train_value, dtype=self.compiled.REAL)
  init_train[name] = train_value
  init_test[name] = np.repeat(value, repeats=n_test, axis=0)

@@ -2153,11 +2253,12 @@ r"""
  # ======================================================================

  # initialize running statistics
- best_params, best_loss, best_grad = policy_params, jnp.inf, jnp.inf
+ best_params, best_loss, best_grad = policy_params, jnp.inf, None
  last_iter_improve = 0
  rolling_test_loss = RollingMean(test_rolling_window)
  log = {}
  status = JaxPlannerStatus.NORMAL
+ progress_percent = 0

  # initialize stopping criterion
  if stopping_rule is not None:
@@ -2169,16 +2270,19 @@ r"""
  dashboard_id, dashboard.get_planner_info(self),
  key=dash_key, viz=self.dashboard_viz)

+ # progress bar
+ if print_progress:
+ progress_bar = tqdm(None, total=100, position=tqdm_position,
+ bar_format='{l_bar}{bar}| {elapsed} {postfix}')
+ else:
+ progress_bar = None
+ position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
+
  # ======================================================================
  # MAIN TRAINING LOOP BEGINS
  # ======================================================================

- iters = range(epochs)
- if print_progress:
- iters = tqdm(iters, total=100, position=tqdm_position)
- position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
-
- for it in iters:
+ for it in range(epochs):

  # ==================================================================
  # NEXT GRADIENT DESCENT STEP
@@ -2189,8 +2293,9 @@ r"""
  # update the parameters of the plan
  key, subkey = random.split(key)
  (policy_params, converged, opt_state, opt_aux, train_loss, train_log,
- model_params) = self.update(subkey, policy_params, policy_hyperparams,
- train_subs, model_params, opt_state, opt_aux)
+ model_params, zero_grads) = self.update(
+ subkey, policy_params, policy_hyperparams, train_subs, model_params,
+ opt_state, opt_aux)
  test_loss, (test_log, model_params_test) = self.test_loss(
  subkey, policy_params, policy_hyperparams, test_subs, model_params_test)
  test_loss_smooth = rolling_test_loss.update(test_loss)
@@ -2200,8 +2305,9 @@ r"""
  if self.use_pgpe:
  key, subkey = random.split(key)
  pgpe_params, r_max, pgpe_opt_state, pgpe_param, pgpe_converged = \
- self.pgpe.update(subkey, pgpe_params, r_max, policy_hyperparams,
- test_subs, model_params, pgpe_opt_state)
+ self.pgpe.update(subkey, pgpe_params, r_max, progress_percent,
+ policy_hyperparams, test_subs, model_params_test,
+ pgpe_opt_state)
  pgpe_loss, _ = self.test_loss(
  subkey, pgpe_param, policy_hyperparams, test_subs, model_params_test)
  pgpe_loss_smooth = rolling_pgpe_loss.update(pgpe_loss)
@@ -2228,7 +2334,7 @@ r"""
  # ==================================================================

  # no progress
- if (not pgpe_improve) and self.check_zero_grad(train_log['grad']):
+ if (not pgpe_improve) and zero_grads:
  status = JaxPlannerStatus.NO_PROGRESS

  # constraint satisfaction problem
@@ -2256,7 +2362,8 @@ r"""
  status = JaxPlannerStatus.ITER_BUDGET_REACHED

  # build a callback
- progress_percent = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
+ progress_percent = 100 * min(
+ 1, max(0, elapsed / train_seconds, it / (epochs - 1)))
  callback = {
  'status': status,
  'iteration': it,
@@ -2279,19 +2386,22 @@ r"""
  'train_log': train_log,
  **test_log
  }
-
+
  # stopping condition reached
  if stopping_rule is not None and stopping_rule.monitor(callback):
  callback['status'] = status = JaxPlannerStatus.STOPPING_RULE_REACHED


  # if the progress bar is used
  if print_progress:
- iters.n = progress_percent
- iters.set_description(
+ progress_bar.set_description(
  f'{position_str} {it:6} it / {-train_loss:14.5f} train / '
  f'{-test_loss_smooth:14.5f} test / {-best_loss:14.5f} best / '
- f'{status.value} status / {total_pgpe_it:6} pgpe'
+ f'{status.value} status / {total_pgpe_it:6} pgpe',
+ refresh=False
  )
+ progress_bar.set_postfix_str(
+ f"{(it + 1) / (elapsed + 1e-6):.2f}it/s", refresh=False)
+ progress_bar.update(progress_percent - progress_bar.n)

  # dash-board
  if dashboard is not None:
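Note: the progress bar is now created once with a fixed total of 100 and advanced manually to the computed progress percentage, instead of wrapping the epoch iterator. A standalone sketch of that tqdm pattern (toy loop, not planner code):

    import time
    from tqdm import tqdm

    progress_bar = tqdm(None, total=100, bar_format='{l_bar}{bar}| {elapsed} {postfix}')
    epochs = 10
    for it in range(epochs):
        time.sleep(0.01)                                   # stand-in for one optimization step
        progress_percent = 100 * (it + 1) / epochs
        progress_bar.set_description(f'{it:6} it', refresh=False)
        progress_bar.set_postfix_str(f'{it + 1} steps', refresh=False)
        progress_bar.update(progress_percent - progress_bar.n)  # move the bar to the new percent
    progress_bar.close()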
@@ -2312,7 +2422,7 @@ r"""

  # release resources
  if print_progress:
- iters.close()
+ progress_bar.close()

  # validate the test return
  if log:
@@ -2332,7 +2442,7 @@ r"""
  last_iter_improve, -train_loss, -test_loss_smooth, -best_loss, grad_norm)
  print(f'summary of optimization:\n'
  f' status ={status}\n'
- f' time ={elapsed:.6f} sec.\n'
+ f' time ={elapsed:.3f} sec.\n'
  f' iterations ={it}\n'
  f' best objective={-best_loss:.6f}\n'
  f' best grad norm={grad_norm}\n'
@@ -2358,12 +2468,12 @@ r"""
  return termcolor.colored(
  '[FAILURE] no progress was made '
  f'and max grad norm {max_grad_norm:.6f} was zero: '
- 'the solver was likely stuck in a plateau.', 'red')
+ 'solver likely stuck in a plateau.', 'red')
  else:
  return termcolor.colored(
  '[FAILURE] no progress was made '
  f'but max grad norm {max_grad_norm:.6f} was non-zero: '
- 'the learning rate or other hyper-parameters were likely suboptimal.',
+ 'learning rate or other hyper-parameters likely suboptimal.',
  'red')

  # model is likely poor IF:
@@ -2372,8 +2482,8 @@ r"""
  return termcolor.colored(
  '[WARNING] progress was made '
  f'but relative train-test error {validation_error:.6f} was high: '
- 'model relaxation around the solution was poor '
- 'or the batch size was too small.', 'yellow')
+ 'poor model relaxation around solution or batch size too small.',
+ 'yellow')

  # model likely did not converge IF:
  # 1. the max grad relative to the return is high
@@ -2383,9 +2493,9 @@ r"""
  return termcolor.colored(
  '[WARNING] progress was made '
  f'but max grad norm {max_grad_norm:.6f} was high: '
- 'the solution was likely locally suboptimal, '
- 'or the relaxed model was not smooth around the solution, '
- 'or the batch size was too small.', 'yellow')
+ 'solution locally suboptimal '
+ 'or relaxed model not smooth around solution '
+ 'or batch size too small.', 'yellow')

  # likely successful
  return termcolor.colored(
@@ -2412,8 +2522,7 @@ r"""
  for (var, values) in subs.items():

  # must not be grounded
- if RDDLPlanningModel.FLUENT_SEP in var \
- or RDDLPlanningModel.OBJECT_SEP in var:
+ if RDDLPlanningModel.FLUENT_SEP in var or RDDLPlanningModel.OBJECT_SEP in var:
  raise ValueError(f'State dictionary passed to the JAX policy is '
  f'grounded, since it contains the key <{var}>, '
  f'but a vectorized environment is required: '
@@ -2421,9 +2530,8 @@ r"""

  # must be numeric array
  # exception is for POMDPs at 1st epoch when observ-fluents are None
- dtype = np.atleast_1d(values).dtype
- if not np.issubdtype(dtype, np.number) \
- and not np.issubdtype(dtype, np.bool_):
+ dtype = np.result_type(values)
+ if not np.issubdtype(dtype, np.number) and not np.issubdtype(dtype, np.bool_):
  if step == 0 and var in self.rddl.observ_fluents:
  subs[var] = self.test_compiled.init_values[var]
  else:
@@ -2435,40 +2543,7 @@ r"""
  actions = self.test_policy(key, params, policy_hyperparams, step, subs)
  actions = jax.tree_map(np.asarray, actions)
  return actions
-
-
- # ***********************************************************************
- # ALL VERSIONS OF RISK FUNCTIONS
- #
- # Based on the original paper "A Distributional Framework for Risk-Sensitive
- # End-to-End Planning in Continuous MDPs" by Patton et al., AAAI 2022.
- #
- # Original risk functions:
- # - entropic utility
- # - mean-variance approximation
- # - conditional value at risk with straight-through gradient trick
- #
- # ***********************************************************************
-
-
- @jax.jit
- def entropic_utility(returns: jnp.ndarray, beta: float) -> float:
- return (-1.0 / beta) * jax.scipy.special.logsumexp(
- -beta * returns, b=1.0 / returns.size)
-
-
- @jax.jit
- def mean_variance_utility(returns: jnp.ndarray, beta: float) -> float:
- return jnp.mean(returns) - 0.5 * beta * jnp.var(returns)
-
-
- @jax.jit
- def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
- var = jnp.percentile(returns, q=100 * alpha)
- mask = returns <= var
- weights = mask / jnp.maximum(1, jnp.sum(mask))
- return jnp.sum(returns * weights)
-
+

  # ***********************************************************************
  # ALL VERSIONS OF CONTROLLERS
@@ -2580,8 +2655,7 @@ class JaxOnlineController(BaseAgent):
  self.callback = callback
  params = callback['best_params']
  self.key, subkey = random.split(self.key)
- actions = planner.get_action(
- subkey, params, 0, state, self.eval_hyperparams)
+ actions = planner.get_action(subkey, params, 0, state, self.eval_hyperparams)
  if self.warm_start:
  self.guess = planner.plan.guess_next_epoch(params)
  return actions