pyRDDLGym-jax 1.2-py3-none-any.whl → 2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
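The headline addition in 2.0 is an optional parameter-exploring policy gradient (PGPE) search that runs alongside the gradient-based planner and can replace its solution whenever it finds a better one. A minimal usage sketch of the new option, assuming the classes are exposed from the planner module as in the 1.x layout (the import path, the `rddl=model` argument, and the `model` variable are assumptions; the `pgpe` keyword and the `GaussianPGPE` defaults are taken from the diff below):

    # hypothetical import path, following the 1.x module layout
    from pyRDDLGym_jax.core.planner import (
        JaxBackpropPlanner, JaxStraightLinePlan, GaussianPGPE)

    # model: an already-compiled RDDL model, as in earlier releases
    pgpe = GaussianPGPE(batch_size=1, init_sigma=1.0, super_symmetric=True)
    planner = JaxBackpropPlanner(rddl=model, plan=JaxStraightLinePlan(), pgpe=pgpe)
    # pass pgpe=None to disable the PGPE search entirely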
@@ -1,12 +1,43 @@
1
+ # ***********************************************************************
2
+ # JAXPLAN
3
+ #
4
+ # Author: Michael Gimelfarb
5
+ #
6
+ # RELEVANT SOURCES:
7
+ #
8
+ # [1] Gimelfarb, Michael, Ayal Taitler, and Scott Sanner. "JaxPlan and GurobiPlan:
9
+ # Optimization Baselines for Replanning in Discrete and Mixed Discrete-Continuous
10
+ # Probabilistic Domains." Proceedings of the International Conference on Automated
11
+ # Planning and Scheduling. Vol. 34. 2024.
12
+ #
13
+ # [2] Patton, Noah, Jihwan Jeong, Mike Gimelfarb, and Scott Sanner. "A Distributional
14
+ # Framework for Risk-Sensitive End-to-End Planning in Continuous MDPs." In Proceedings of
15
+ # the AAAI Conference on Artificial Intelligence, vol. 36, no. 9, pp. 9894-9901. 2022.
16
+ #
17
+ # [3] Bueno, Thiago P., Leliane N. de Barros, Denis D. Mauá, and Scott Sanner. "Deep
18
+ # reactive policies for planning in stochastic nonlinear domains." In Proceedings of the
19
+ # AAAI Conference on Artificial Intelligence, vol. 33, no. 01, pp. 7530-7537. 2019.
20
+ #
21
+ # [4] Wu, Ga, Buser Say, and Scott Sanner. "Scalable planning with tensorflow for hybrid
22
+ # nonlinear domains." Advances in Neural Information Processing Systems 30 (2017).
23
+ #
24
+ # [5] Sehnke, Frank, and Tingting Zhao. "Baseline-free sampling in parameter exploring
25
+ # policy gradients: Super symmetric pgpe." Artificial Neural Networks: Methods and
26
+ # Applications in Bio-/Neuroinformatics. Springer International Publishing, 2015.
27
+ #
28
+ # ***********************************************************************
29
+
30
+
1
31
  from ast import literal_eval
2
32
  from collections import deque
3
33
  import configparser
4
34
  from enum import Enum
35
+ from functools import partial
5
36
  import os
6
37
  import sys
7
38
  import time
8
39
  import traceback
9
- from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Tuple, Union
40
+ from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Type, Tuple, Union
10
41
 
11
42
  import haiku as hk
12
43
  import jax
@@ -163,7 +194,20 @@ def _load_config(config, args):
163
194
  del planner_args['optimizer']
164
195
  else:
165
196
  planner_args['optimizer'] = optimizer
166
-
197
+
198
+ # pgpe optimizer
199
+ pgpe_method = planner_args.get('pgpe', 'GaussianPGPE')
200
+ pgpe_kwargs = planner_args.pop('pgpe_kwargs', {})
201
+ if pgpe_method is not None:
202
+ if 'optimizer' in pgpe_kwargs:
203
+ pgpe_optimizer = _getattr_any(packages=[optax], item=pgpe_kwargs['optimizer'])
204
+ if pgpe_optimizer is None:
205
+ raise_warning(f'Ignoring invalid optimizer <{pgpe_kwargs["optimizer"]}>.', 'red')
206
+ del pgpe_kwargs['optimizer']
207
+ else:
208
+ pgpe_kwargs['optimizer'] = pgpe_optimizer
209
+ planner_args['pgpe'] = getattr(sys.modules[__name__], pgpe_method)(**pgpe_kwargs)
210
+
167
211
  # optimize call RNG key
168
212
  planner_key = train_args.get('key', None)
169
213
  if planner_key is not None:
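The block above mirrors the existing optimizer parsing for the new PGPE option: a `pgpe` key names a PGPE class defined in this module (defaulting to GaussianPGPE), `pgpe_kwargs` is passed to that class's constructor, and any `optimizer` string inside it is resolved against optax. A hypothetical config fragment under those assumptions (the [Optimizer] section name follows the 1.x config layout and is not shown in this diff):

    [Optimizer]
    method='JaxStraightLinePlan'
    pgpe='GaussianPGPE'
    pgpe_kwargs={'init_sigma': 0.5, 'optimizer': 'rmsprop', 'optimizer_kwargs_mu': {'learning_rate': 0.05}}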
@@ -469,16 +513,16 @@ class JaxStraightLinePlan(JaxPlan):
469
513
  bounds = '\n '.join(
470
514
  map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
471
515
  return (f'policy hyper-parameters:\n'
472
- f' initializer ={self._initializer_base}\n'
473
- f'constraint-sat strategy (simple):\n'
474
- f' parsed_action_bounds =\n {bounds}\n'
475
- f' wrap_sigmoid ={self._wrap_sigmoid}\n'
476
- f' wrap_sigmoid_min_prob={self._min_action_prob}\n'
477
- f' wrap_non_bool ={self._wrap_non_bool}\n'
478
- f'constraint-sat strategy (complex):\n'
479
- f' wrap_softmax ={self._wrap_softmax}\n'
480
- f' use_new_projection ={self._use_new_projection}\n'
481
- f' max_projection_iters ={self._max_constraint_iter}')
516
+ f' initializer={self._initializer_base}\n'
517
+ f' constraint-sat strategy (simple):\n'
518
+ f' parsed_action_bounds =\n {bounds}\n'
519
+ f' wrap_sigmoid ={self._wrap_sigmoid}\n'
520
+ f' wrap_sigmoid_min_prob={self._min_action_prob}\n'
521
+ f' wrap_non_bool ={self._wrap_non_bool}\n'
522
+ f' constraint-sat strategy (complex):\n'
523
+ f' wrap_softmax ={self._wrap_softmax}\n'
524
+ f' use_new_projection ={self._use_new_projection}\n'
525
+ f' max_projection_iters={self._max_constraint_iter}\n')
482
526
 
483
527
  def compile(self, compiled: JaxRDDLCompilerWithGrad,
484
528
  _bounds: Bounds,
@@ -655,7 +699,10 @@ class JaxStraightLinePlan(JaxPlan):
655
699
  if ranges[var] == 'bool':
656
700
  param_flat = jnp.ravel(param)
657
701
  if noop[var]:
658
- param_flat = (-param_flat) if wrap_sigmoid else 1.0 - param_flat
702
+ if wrap_sigmoid:
703
+ param_flat = -param_flat
704
+ else:
705
+ param_flat = 1.0 - param_flat
659
706
  scores.append(param_flat)
660
707
  scores = jnp.concatenate(scores)
661
708
  descending = jnp.sort(scores)[::-1]
@@ -666,7 +713,10 @@ class JaxStraightLinePlan(JaxPlan):
666
713
  new_params = {}
667
714
  for (var, param) in params.items():
668
715
  if ranges[var] == 'bool':
669
- new_param = param + (surplus if noop[var] else -surplus)
716
+ if noop[var]:
717
+ new_param = param + surplus
718
+ else:
719
+ new_param = param - surplus
670
720
  new_param = _jax_project_bool_to_box(var, new_param, hyperparams)
671
721
  else:
672
722
  new_param = param
@@ -687,57 +737,73 @@ class JaxStraightLinePlan(JaxPlan):
687
737
  elif use_constraint_satisfaction and not self._use_new_projection:
688
738
 
689
739
  # calculate the surplus of actions above max-nondef-actions
690
- def _jax_wrapped_sogbofa_surplus(params, hyperparams):
691
- sum_action, count = 0.0, 0
692
- for (var, param) in params.items():
740
+ def _jax_wrapped_sogbofa_surplus(actions):
741
+ sum_action, k = 0.0, 0
742
+ for (var, action) in actions.items():
693
743
  if ranges[var] == 'bool':
694
- action = _jax_bool_param_to_action(var, param, hyperparams)
695
744
  if noop[var]:
696
- sum_action += jnp.size(action) - jnp.sum(action)
697
- count += jnp.sum(action < 1)
698
- else:
699
- sum_action += jnp.sum(action)
700
- count += jnp.sum(action > 0)
745
+ action = 1 - action
746
+ sum_action += jnp.sum(action)
747
+ k += jnp.count_nonzero(action)
701
748
  surplus = jnp.maximum(sum_action - allowed_actions, 0.0)
702
- count = jnp.maximum(count, 1)
703
- return surplus / count
749
+ return surplus, k
704
750
 
705
751
  # return whether the surplus is positive or reached compute limit
706
752
  max_constraint_iter = self._max_constraint_iter
707
753
 
708
754
  def _jax_wrapped_sogbofa_continue(values):
709
- it, _, _, surplus = values
710
- return jnp.logical_and(it < max_constraint_iter, surplus > 0)
755
+ it, _, surplus, k = values
756
+ return jnp.logical_and(
757
+ it < max_constraint_iter, jnp.logical_and(surplus > 0, k > 0))
711
758
 
712
759
  # reduce all bool action values by the surplus clipping at minimum
713
760
  # for no-op = True, do the opposite, i.e. increase all
714
761
  # bool action values by surplus clipping at maximum
715
762
  def _jax_wrapped_sogbofa_subtract_surplus(values):
716
- it, params, hyperparams, surplus = values
717
- new_params = {}
718
- for (var, param) in params.items():
763
+ it, actions, surplus, k = values
764
+ amount = surplus / k
765
+ new_actions = {}
766
+ for (var, action) in actions.items():
719
767
  if ranges[var] == 'bool':
720
- action = _jax_bool_param_to_action(var, param, hyperparams)
721
- new_action = action + (surplus if noop[var] else -surplus)
722
- new_action = jnp.clip(new_action, min_action, max_action)
723
- new_param = _jax_bool_action_to_param(var, new_action, hyperparams)
768
+ if noop[var]:
769
+ new_actions[var] = jnp.minimum(action + amount, 1)
770
+ else:
771
+ new_actions[var] = jnp.maximum(action - amount, 0)
724
772
  else:
725
- new_param = param
726
- new_params[var] = new_param
727
- new_surplus = _jax_wrapped_sogbofa_surplus(new_params, hyperparams)
773
+ new_actions[var] = action
774
+ new_surplus, new_k = _jax_wrapped_sogbofa_surplus(new_actions)
728
775
  new_it = it + 1
729
- return new_it, new_params, hyperparams, new_surplus
776
+ return new_it, new_actions, new_surplus, new_k
730
777
 
731
778
  # apply the surplus to the actions until it becomes zero
732
779
  def _jax_wrapped_sogbofa_project(params, hyperparams):
733
- surplus = _jax_wrapped_sogbofa_surplus(params, hyperparams)
734
- _, params, _, surplus = jax.lax.while_loop(
780
+
781
+ # convert parameters to actions
782
+ actions = {}
783
+ for (var, param) in params.items():
784
+ if ranges[var] == 'bool':
785
+ actions[var] = _jax_bool_param_to_action(var, param, hyperparams)
786
+ else:
787
+ actions[var] = param
788
+
789
+ # run SOGBOFA loop on the actions to get adjusted actions
790
+ surplus, k = _jax_wrapped_sogbofa_surplus(actions)
791
+ _, actions, surplus, k = jax.lax.while_loop(
735
792
  cond_fun=_jax_wrapped_sogbofa_continue,
736
793
  body_fun=_jax_wrapped_sogbofa_subtract_surplus,
737
- init_val=(0, params, hyperparams, surplus)
794
+ init_val=(0, actions, surplus, k)
738
795
  )
739
796
  converged = jnp.logical_not(surplus > 0)
740
- return params, converged
797
+
798
+ # convert the adjusted actions back to parameters
799
+ new_params = {}
800
+ for (var, action) in actions.items():
801
+ if ranges[var] == 'bool':
802
+ action = jnp.clip(action, min_action, max_action)
803
+ new_params[var] = _jax_bool_action_to_param(var, action, hyperparams)
804
+ else:
805
+ new_params[var] = action
806
+ return new_params, converged
741
807
 
742
808
  # clip actions to valid bounds and satisfy constraint on max actions
743
809
  def _jax_wrapped_slp_project_to_max_constraint(params, hyperparams):
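The refactored projection above now works on action values rather than raw parameters: the total boolean activation in excess of the max-nondef-actions budget is spread evenly over the k currently active actions, clipped to valid bounds, and the loop repeats until no surplus remains or the iteration cap is hit. A simplified non-JAX sketch of the same idea (illustrative only: it skips the parameter/action conversion and the inverted handling of variables whose no-op value is true):

    import numpy as np

    def project_max_actions(actions, budget, max_iters=100):
        # actions: dict of arrays of relaxed boolean action values in [0, 1]
        for _ in range(max_iters):
            total = sum(np.sum(a) for a in actions.values())
            k = sum(np.count_nonzero(a) for a in actions.values())
            surplus = max(total - budget, 0.0)
            if surplus <= 0 or k == 0:
                break
            amount = surplus / k  # spread the excess evenly over active actions
            actions = {var: np.maximum(a - amount, 0.0)
                       for (var, a) in actions.items()}
        return actions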
@@ -834,15 +900,16 @@ class JaxDeepReactivePolicy(JaxPlan):
834
900
  bounds = '\n '.join(
835
901
  map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
836
902
  return (f'policy hyper-parameters:\n'
837
- f' topology ={self._topology}\n'
838
- f' activation_fn ={self._activations[0].__name__}\n'
839
- f' initializer ={type(self._initializer_base).__name__}\n'
840
- f' apply_input_norm ={self._normalize}\n'
841
- f' input_norm_layerwise={self._normalize_per_layer}\n'
842
- f' input_norm_args ={self._normalizer_kwargs}\n'
843
- f'constraint-sat strategy:\n'
844
- f' parsed_action_bounds=\n {bounds}\n'
845
- f' wrap_non_bool ={self._wrap_non_bool}')
903
+ f' topology ={self._topology}\n'
904
+ f' activation_fn={self._activations[0].__name__}\n'
905
+ f' initializer ={type(self._initializer_base).__name__}\n'
906
+ f' input norm:\n'
907
+ f' apply_input_norm ={self._normalize}\n'
908
+ f' input_norm_layerwise={self._normalize_per_layer}\n'
909
+ f' input_norm_args ={self._normalizer_kwargs}\n'
910
+ f' constraint-sat strategy:\n'
911
+ f' parsed_action_bounds=\n {bounds}\n'
912
+ f' wrap_non_bool ={self._wrap_non_bool}\n')
846
913
 
847
914
  def compile(self, compiled: JaxRDDLCompilerWithGrad,
848
915
  _bounds: Bounds,
@@ -1068,10 +1135,11 @@ class JaxDeepReactivePolicy(JaxPlan):
1068
1135
 
1069
1136
 
1070
1137
  # ***********************************************************************
1071
- # ALL VERSIONS OF JAX PLANNER
1138
+ # SUPPORTING FUNCTIONS
1072
1139
  #
1073
- # - simple gradient descent based planner
1074
- # - more stable but slower line search based planner
1140
+ # - smoothed mean calculation
1141
+ # - planner status
1142
+ # - stopping criteria
1075
1143
  #
1076
1144
  # ***********************************************************************
1077
1145
 
@@ -1145,6 +1213,264 @@ class NoImprovementStoppingRule(JaxPlannerStoppingRule):
1145
1213
  return f'No improvement for {self.patience} iterations'
1146
1214
 
1147
1215
 
1216
+ # ***********************************************************************
1217
+ # PARAMETER EXPLORING POLICY GRADIENTS (PGPE)
1218
+ #
1219
+ # - simple Gaussian PGPE
1220
+ #
1221
+ # ***********************************************************************
1222
+
1223
+
1224
+ class PGPE:
1225
+ """Base class for all PGPE strategies."""
1226
+
1227
+ def __init__(self) -> None:
1228
+ self._initializer = None
1229
+ self._update = None
1230
+
1231
+ @property
1232
+ def initialize(self):
1233
+ return self._initializer
1234
+
1235
+ @property
1236
+ def update(self):
1237
+ return self._update
1238
+
1239
+ def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type) -> None:
1240
+ raise NotImplementedError
1241
+
1242
+
1243
+ class GaussianPGPE(PGPE):
1244
+ '''PGPE with a Gaussian parameter distribution.'''
1245
+
1246
+ def __init__(self, batch_size: int=1,
1247
+ init_sigma: float=1.0,
1248
+ sigma_range: Tuple[float, float]=(1e-5, 1e5),
1249
+ scale_reward: bool=True,
1250
+ super_symmetric: bool=True,
1251
+ super_symmetric_accurate: bool=True,
1252
+ optimizer: Callable[..., optax.GradientTransformation]=optax.adam,
1253
+ optimizer_kwargs_mu: Optional[Kwargs]=None,
1254
+ optimizer_kwargs_sigma: Optional[Kwargs]=None) -> None:
1255
+ '''Creates a new Gaussian PGPE planner.
1256
+
1257
+ :param batch_size: how many policy parameters to sample per optimization step
1258
+ :param init_sigma: initial standard deviation of Gaussian
1259
+ :param sigma_range: bounds to constrain standard deviation
1260
+ :param scale_reward: whether to apply reward scaling as described in [5]
1261
+ :param super_symmetric: whether to use super-symmetric sampling as described in [5]
1262
+ :param super_symmetric_accurate: whether to use the accurate formula for super-
1263
+ symmetric sampling or the simplified but biased formula
1264
+ :param optimizer: a factory for an optax SGD algorithm
1265
+ :param optimizer_kwargs_mu: a dictionary of parameters to pass to the SGD
1266
+ factory for the mean optimizer
1267
+ :param optimizer_kwargs_sigma: a dictionary of parameters to pass to the SGD
1268
+ factory for the standard deviation optimizer
1269
+ '''
1270
+ super().__init__()
1271
+
1272
+ self.batch_size = batch_size
1273
+ self.init_sigma = init_sigma
1274
+ self.sigma_range = sigma_range
1275
+ self.scale_reward = scale_reward
1276
+ self.super_symmetric = super_symmetric
1277
+ self.super_symmetric_accurate = super_symmetric_accurate
1278
+
1279
+ # set optimizers
1280
+ if optimizer_kwargs_mu is None:
1281
+ optimizer_kwargs_mu = {'learning_rate': 0.1}
1282
+ self.optimizer_kwargs_mu = optimizer_kwargs_mu
1283
+ if optimizer_kwargs_sigma is None:
1284
+ optimizer_kwargs_sigma = {'learning_rate': 0.1}
1285
+ self.optimizer_kwargs_sigma = optimizer_kwargs_sigma
1286
+ self.optimizer_name = optimizer
1287
+ mu_optimizer = optimizer(**optimizer_kwargs_mu)
1288
+ sigma_optimizer = optimizer(**optimizer_kwargs_sigma)
1289
+ self.optimizers = (mu_optimizer, sigma_optimizer)
1290
+
1291
+ def __str__(self) -> str:
1292
+ return (f'PGPE hyper-parameters:\n'
1293
+ f' method ={self.__class__.__name__}\n'
1294
+ f' batch_size ={self.batch_size}\n'
1295
+ f' init_sigma ={self.init_sigma}\n'
1296
+ f' sigma_range ={self.sigma_range}\n'
1297
+ f' scale_reward ={self.scale_reward}\n'
1298
+ f' super_symmetric={self.super_symmetric}\n'
1299
+ f' accurate ={self.super_symmetric_accurate}\n'
1300
+ f' optimizer ={self.optimizer_name}\n'
1301
+ f' optimizer_kwargs:\n'
1302
+ f' mu ={self.optimizer_kwargs_mu}\n'
1303
+ f' sigma={self.optimizer_kwargs_sigma}\n'
1304
+ )
1305
+
1306
+ def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type) -> None:
1307
+ MIN_NORM = 1e-5
1308
+ sigma0 = self.init_sigma
1309
+ sigma_range = self.sigma_range
1310
+ scale_reward = self.scale_reward
1311
+ super_symmetric = self.super_symmetric
1312
+ super_symmetric_accurate = self.super_symmetric_accurate
1313
+ batch_size = self.batch_size
1314
+ optimizers = (mu_optimizer, sigma_optimizer) = self.optimizers
1315
+
1316
+ # initializer
1317
+ def _jax_wrapped_pgpe_init(key, policy_params):
1318
+ mu = policy_params
1319
+ sigma = jax.tree_map(lambda x: sigma0 * jnp.ones_like(x), mu)
1320
+ pgpe_params = (mu, sigma)
1321
+ pgpe_opt_state = tuple(opt.init(param)
1322
+ for (opt, param) in zip(optimizers, pgpe_params))
1323
+ return pgpe_params, pgpe_opt_state
1324
+
1325
+ self._initializer = jax.jit(_jax_wrapped_pgpe_init)
1326
+
1327
+ # parameter sampling functions
1328
+ def _jax_wrapped_mu_noise(key, sigma):
1329
+ return sigma * random.normal(key, shape=jnp.shape(sigma), dtype=real_dtype)
1330
+
1331
+ def _jax_wrapped_epsilon_star(sigma, epsilon):
1332
+ c1, c2, c3 = -0.06655, -0.9706, 0.124
1333
+ phi = 0.67449 * sigma
1334
+ a = (sigma - jnp.abs(epsilon)) / sigma
1335
+ if super_symmetric_accurate:
1336
+ aa = jnp.abs(a)
1337
+ epsilon_star = jnp.sign(epsilon) * phi * jnp.where(
1338
+ a <= 0,
1339
+ jnp.exp(c1 * aa * (aa * aa - 1) / jnp.log(aa + 1e-10) + c2 * aa),
1340
+ jnp.exp(aa - c3 * aa * jnp.log(1.0 - jnp.power(aa, 3) + 1e-10))
1341
+ )
1342
+ else:
1343
+ epsilon_star = jnp.sign(epsilon) * phi * jnp.exp(a)
1344
+ return epsilon_star
1345
+
1346
+ def _jax_wrapped_sample_params(key, mu, sigma):
1347
+ keys = random.split(key, num=len(jax.tree_util.tree_leaves(mu)))
1348
+ keys_pytree = jax.tree_util.tree_unflatten(
1349
+ treedef=jax.tree_util.tree_structure(mu), leaves=keys)
1350
+ epsilon = jax.tree_map(_jax_wrapped_mu_noise, keys_pytree, sigma)
1351
+ p1 = jax.tree_map(jnp.add, mu, epsilon)
1352
+ p2 = jax.tree_map(jnp.subtract, mu, epsilon)
1353
+ if super_symmetric:
1354
+ epsilon_star = jax.tree_map(_jax_wrapped_epsilon_star, sigma, epsilon)
1355
+ p3 = jax.tree_map(jnp.add, mu, epsilon_star)
1356
+ p4 = jax.tree_map(jnp.subtract, mu, epsilon_star)
1357
+ else:
1358
+ epsilon_star, p3, p4 = epsilon, p1, p2
1359
+ return (p1, p2, p3, p4), (epsilon, epsilon_star)
1360
+
1361
+ # policy gradient update functions
1362
+ def _jax_wrapped_mu_grad(epsilon, epsilon_star, r1, r2, r3, r4, m):
1363
+ if super_symmetric:
1364
+ if scale_reward:
1365
+ scale1 = jnp.maximum(MIN_NORM, m - (r1 + r2) / 2)
1366
+ scale2 = jnp.maximum(MIN_NORM, m - (r3 + r4) / 2)
1367
+ else:
1368
+ scale1 = scale2 = 1.0
1369
+ r_mu1 = (r1 - r2) / (2 * scale1)
1370
+ r_mu2 = (r3 - r4) / (2 * scale2)
1371
+ grad = -(r_mu1 * epsilon + r_mu2 * epsilon_star)
1372
+ else:
1373
+ if scale_reward:
1374
+ scale = jnp.maximum(MIN_NORM, m - (r1 + r2) / 2)
1375
+ else:
1376
+ scale = 1.0
1377
+ r_mu = (r1 - r2) / (2 * scale)
1378
+ grad = -r_mu * epsilon
1379
+ return grad
1380
+
1381
+ def _jax_wrapped_sigma_grad(epsilon, epsilon_star, sigma, r1, r2, r3, r4, m):
1382
+ if super_symmetric:
1383
+ mask = r1 + r2 >= r3 + r4
1384
+ epsilon_tau = mask * epsilon + (1 - mask) * epsilon_star
1385
+ s = epsilon_tau * epsilon_tau / sigma - sigma
1386
+ if scale_reward:
1387
+ scale = jnp.maximum(MIN_NORM, m - (r1 + r2 + r3 + r4) / 4)
1388
+ else:
1389
+ scale = 1.0
1390
+ r_sigma = ((r1 + r2) - (r3 + r4)) / (4 * scale)
1391
+ else:
1392
+ s = epsilon * epsilon / sigma - sigma
1393
+ if scale_reward:
1394
+ scale = jnp.maximum(MIN_NORM, jnp.abs(m))
1395
+ else:
1396
+ scale = 1.0
1397
+ r_sigma = (r1 + r2) / (2 * scale)
1398
+ grad = -r_sigma * s
1399
+ return grad
1400
+
1401
+ def _jax_wrapped_pgpe_grad(key, mu, sigma, r_max,
1402
+ policy_hyperparams, subs, model_params):
1403
+ key, subkey = random.split(key)
1404
+ (p1, p2, p3, p4), (epsilon, epsilon_star) = _jax_wrapped_sample_params(
1405
+ key, mu, sigma)
1406
+ r1 = -loss_fn(subkey, p1, policy_hyperparams, subs, model_params)[0]
1407
+ r2 = -loss_fn(subkey, p2, policy_hyperparams, subs, model_params)[0]
1408
+ r_max = jnp.maximum(r_max, r1)
1409
+ r_max = jnp.maximum(r_max, r2)
1410
+ if super_symmetric:
1411
+ r3 = -loss_fn(subkey, p3, policy_hyperparams, subs, model_params)[0]
1412
+ r4 = -loss_fn(subkey, p4, policy_hyperparams, subs, model_params)[0]
1413
+ r_max = jnp.maximum(r_max, r3)
1414
+ r_max = jnp.maximum(r_max, r4)
1415
+ else:
1416
+ r3, r4 = r1, r2
1417
+ grad_mu = jax.tree_map(
1418
+ partial(_jax_wrapped_mu_grad, r1=r1, r2=r2, r3=r3, r4=r4, m=r_max),
1419
+ epsilon, epsilon_star
1420
+ )
1421
+ grad_sigma = jax.tree_map(
1422
+ partial(_jax_wrapped_sigma_grad, r1=r1, r2=r2, r3=r3, r4=r4, m=r_max),
1423
+ epsilon, epsilon_star, sigma
1424
+ )
1425
+ return grad_mu, grad_sigma, r_max
1426
+
1427
+ def _jax_wrapped_pgpe_grad_batched(key, pgpe_params, r_max,
1428
+ policy_hyperparams, subs, model_params):
1429
+ mu, sigma = pgpe_params
1430
+ if batch_size == 1:
1431
+ mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad(
1432
+ key, mu, sigma, r_max, policy_hyperparams, subs, model_params)
1433
+ else:
1434
+ keys = random.split(key, num=batch_size)
1435
+ mu_grads, sigma_grads, r_maxs = jax.vmap(
1436
+ _jax_wrapped_pgpe_grad,
1437
+ in_axes=(0, None, None, None, None, None, None)
1438
+ )(keys, mu, sigma, r_max, policy_hyperparams, subs, model_params)
1439
+ mu_grad = jax.tree_map(partial(jnp.mean, axis=0), mu_grads)
1440
+ sigma_grad = jax.tree_map(partial(jnp.mean, axis=0), sigma_grads)
1441
+ new_r_max = jnp.max(r_maxs)
1442
+ return mu_grad, sigma_grad, new_r_max
1443
+
1444
+ def _jax_wrapped_pgpe_update(key, pgpe_params, r_max,
1445
+ policy_hyperparams, subs, model_params,
1446
+ pgpe_opt_state):
1447
+ mu, sigma = pgpe_params
1448
+ mu_state, sigma_state = pgpe_opt_state
1449
+ mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad_batched(
1450
+ key, pgpe_params, r_max, policy_hyperparams, subs, model_params)
1451
+ mu_updates, new_mu_state = mu_optimizer.update(mu_grad, mu_state, params=mu)
1452
+ sigma_updates, new_sigma_state = sigma_optimizer.update(
1453
+ sigma_grad, sigma_state, params=sigma)
1454
+ new_mu = optax.apply_updates(mu, mu_updates)
1455
+ new_mu, converged = projection(new_mu, policy_hyperparams)
1456
+ new_sigma = optax.apply_updates(sigma, sigma_updates)
1457
+ new_sigma = jax.tree_map(lambda x: jnp.clip(x, *sigma_range), new_sigma)
1458
+ new_pgpe_params = (new_mu, new_sigma)
1459
+ new_pgpe_opt_state = (new_mu_state, new_sigma_state)
1460
+ policy_params = new_mu
1461
+ return new_pgpe_params, new_r_max, new_pgpe_opt_state, policy_params, converged
1462
+
1463
+ self._update = jax.jit(_jax_wrapped_pgpe_update)
1464
+
1465
+
1466
+ # ***********************************************************************
1467
+ # ALL VERSIONS OF JAX PLANNER
1468
+ #
1469
+ # - simple gradient descent based planner
1470
+ #
1471
+ # ***********************************************************************
1472
+
1473
+
1148
1474
  class JaxBackpropPlanner:
1149
1475
  '''A class for optimizing an action sequence in the given RDDL MDP using
1150
1476
  gradient descent.'''
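The GaussianPGPE update added in the section above follows the symmetric-sampling PGPE estimator: perturb the parameter mean mu with Gaussian noise epsilon, evaluate the return at mu + epsilon and mu - epsilon, move mu along epsilon in proportion to the return difference, and adapt sigma along the centered second-order direction. A stripped-down single-vector sketch (plain NumPy, fixed learning rates, no reward scaling, no super-symmetric sampling, no optax state), intended only to mirror the structure of the jitted update:

    import numpy as np

    def pgpe_step(mu, sigma, returns_fn, lr_mu=0.1, lr_sigma=0.1, rng=np.random):
        # symmetric perturbation of the mean
        eps = sigma * rng.standard_normal(mu.shape)
        r_plus, r_minus = returns_fn(mu + eps), returns_fn(mu - eps)
        # the mean moves along eps by the return difference
        grad_mu = 0.5 * (r_plus - r_minus) * eps
        # sigma moves along the centered second-order direction
        grad_sigma = 0.5 * (r_plus + r_minus) * (eps ** 2 - sigma ** 2) / sigma
        return mu + lr_mu * grad_mu, sigma + lr_sigma * grad_sigma

In the class itself these two directions are computed per pytree leaf, negated into loss gradients, and fed to separate optax optimizers for mu and sigma, with sigma clipped to sigma_range after every step.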
@@ -1161,6 +1487,7 @@ class JaxBackpropPlanner:
1161
1487
  clip_grad: Optional[float]=None,
1162
1488
  line_search_kwargs: Optional[Kwargs]=None,
1163
1489
  noise_kwargs: Optional[Kwargs]=None,
1490
+ pgpe: Optional[PGPE]=GaussianPGPE(),
1164
1491
  logic: Logic=FuzzyLogic(),
1165
1492
  use_symlog_reward: bool=False,
1166
1493
  utility: Union[Callable[[jnp.ndarray], float], str]='mean',
@@ -1191,6 +1518,7 @@ class JaxBackpropPlanner:
1191
1518
  :param line_search_kwargs: parameters to pass to optional line search
1192
1519
  method to scale learning rate
1193
1520
  :param noise_kwargs: parameters of optional gradient noise
1521
+ :param pgpe: optional parameter-exploring policy gradient (PGPE) strategy to run alongside the gradient-based planner
1194
1522
  :param logic: a subclass of Logic for mapping exact mathematical
1195
1523
  operations to their differentiable counterparts
1196
1524
  :param use_symlog_reward: whether to use the symlog transform on the
@@ -1229,6 +1557,8 @@ class JaxBackpropPlanner:
1229
1557
  self.clip_grad = clip_grad
1230
1558
  self.line_search_kwargs = line_search_kwargs
1231
1559
  self.noise_kwargs = noise_kwargs
1560
+ self.pgpe = pgpe
1561
+ self.use_pgpe = pgpe is not None
1232
1562
 
1233
1563
  # set optimizer
1234
1564
  try:
@@ -1333,24 +1663,25 @@ r"""
1333
1663
  f' line_search_kwargs={self.line_search_kwargs}\n'
1334
1664
  f' noise_kwargs ={self.noise_kwargs}\n'
1335
1665
  f' batch_size_train ={self.batch_size_train}\n'
1336
- f' batch_size_test ={self.batch_size_test}')
1337
- result += '\n' + str(self.plan)
1338
- result += '\n' + str(self.logic)
1666
+ f' batch_size_test ={self.batch_size_test}\n')
1667
+ result += str(self.plan)
1668
+ if self.use_pgpe:
1669
+ result += str(self.pgpe)
1670
+ result += str(self.logic)
1339
1671
 
1340
1672
  # print model relaxation information
1341
- if not self.compiled.model_params:
1342
- return result
1343
- result += '\n' + ('Some RDDL operations are non-differentiable '
1344
- 'and will be approximated as follows:' + '\n')
1345
- exprs_by_rddl_op, values_by_rddl_op = {}, {}
1346
- for info in self.compiled.model_parameter_info().values():
1347
- rddl_op = info['rddl_op']
1348
- exprs_by_rddl_op.setdefault(rddl_op, []).append(info['id'])
1349
- values_by_rddl_op.setdefault(rddl_op, []).append(info['init_value'])
1350
- for rddl_op in sorted(exprs_by_rddl_op.keys()):
1351
- result += (f' {rddl_op}:\n'
1352
- f' addresses ={exprs_by_rddl_op[rddl_op]}\n'
1353
- f' init_values={values_by_rddl_op[rddl_op]}\n')
1673
+ if self.compiled.model_params:
1674
+ result += ('Some RDDL operations are non-differentiable '
1675
+ 'and will be approximated as follows:' + '\n')
1676
+ exprs_by_rddl_op, values_by_rddl_op = {}, {}
1677
+ for info in self.compiled.model_parameter_info().values():
1678
+ rddl_op = info['rddl_op']
1679
+ exprs_by_rddl_op.setdefault(rddl_op, []).append(info['id'])
1680
+ values_by_rddl_op.setdefault(rddl_op, []).append(info['init_value'])
1681
+ for rddl_op in sorted(exprs_by_rddl_op.keys()):
1682
+ result += (f' {rddl_op}:\n'
1683
+ f' addresses ={exprs_by_rddl_op[rddl_op]}\n'
1684
+ f' init_values={values_by_rddl_op[rddl_op]}\n')
1354
1685
  return result
1355
1686
 
1356
1687
  def summarize_hyperparameters(self) -> None:
@@ -1415,6 +1746,16 @@ r"""
1415
1746
 
1416
1747
  # optimization
1417
1748
  self.update = self._jax_update(train_loss)
1749
+ self.check_zero_grad = self._jax_check_zero_gradients()
1750
+
1751
+ # pgpe option
1752
+ if self.use_pgpe:
1753
+ loss_fn = self._jax_loss(rollouts=test_rollouts)
1754
+ self.pgpe.compile(
1755
+ loss_fn=loss_fn,
1756
+ projection=self.plan.projection,
1757
+ real_dtype=self.test_compiled.REAL
1758
+ )
1418
1759
 
1419
1760
  def _jax_return(self, use_symlog):
1420
1761
  gamma = self.rddl.discount
@@ -1497,6 +1838,18 @@ r"""
1497
1838
 
1498
1839
  return jax.jit(_jax_wrapped_plan_update)
1499
1840
 
1841
+ def _jax_check_zero_gradients(self):
1842
+
1843
+ def _jax_wrapped_zero_gradient(grad):
1844
+ return jnp.allclose(grad, 0)
1845
+
1846
+ def _jax_wrapped_zero_gradients(grad):
1847
+ leaves, _ = jax.tree_util.tree_flatten(
1848
+ jax.tree_map(_jax_wrapped_zero_gradient, grad))
1849
+ return jnp.all(jnp.asarray(leaves))
1850
+
1851
+ return jax.jit(_jax_wrapped_zero_gradients)
1852
+
1500
1853
  def _batched_init_subs(self, subs):
1501
1854
  rddl = self.rddl
1502
1855
  n_train, n_test = self.batch_size_train, self.batch_size_test
@@ -1611,7 +1964,7 @@ r"""
1611
1964
  return grad
1612
1965
 
1613
1966
  return _loss_function, _grad_function, guess_1d, jax.jit(unravel_fn)
1614
-
1967
+
1615
1968
  # ===========================================================================
1616
1969
  # OPTIMIZE API
1617
1970
  # ===========================================================================
@@ -1784,7 +2137,17 @@ r"""
1784
2137
  policy_params = guess
1785
2138
  opt_state = self.optimizer.init(policy_params)
1786
2139
  opt_aux = {}
1787
-
2140
+
2141
+ # initialize pgpe parameters
2142
+ if self.use_pgpe:
2143
+ pgpe_params, pgpe_opt_state = self.pgpe.initialize(key, policy_params)
2144
+ rolling_pgpe_loss = RollingMean(test_rolling_window)
2145
+ else:
2146
+ pgpe_params, pgpe_opt_state = None, None
2147
+ rolling_pgpe_loss = None
2148
+ total_pgpe_it = 0
2149
+ r_max = -jnp.inf
2150
+
1788
2151
  # ======================================================================
1789
2152
  # INITIALIZATION OF RUNNING STATISTICS
1790
2153
  # ======================================================================
@@ -1795,7 +2158,6 @@ r"""
1795
2158
  rolling_test_loss = RollingMean(test_rolling_window)
1796
2159
  log = {}
1797
2160
  status = JaxPlannerStatus.NORMAL
1798
- is_all_zero_fn = lambda x: np.allclose(x, 0)
1799
2161
 
1800
2162
  # initialize stopping criterion
1801
2163
  if stopping_rule is not None:
@@ -1826,19 +2188,47 @@ r"""
1826
2188
 
1827
2189
  # update the parameters of the plan
1828
2190
  key, subkey = random.split(key)
1829
- (policy_params, converged, opt_state, opt_aux,
1830
- train_loss, train_log, model_params) = \
1831
- self.update(subkey, policy_params, policy_hyperparams,
1832
- train_subs, model_params, opt_state, opt_aux)
1833
-
2191
+ (policy_params, converged, opt_state, opt_aux, train_loss, train_log,
2192
+ model_params) = self.update(subkey, policy_params, policy_hyperparams,
2193
+ train_subs, model_params, opt_state, opt_aux)
2194
+ test_loss, (test_log, model_params_test) = self.test_loss(
2195
+ subkey, policy_params, policy_hyperparams, test_subs, model_params_test)
2196
+ test_loss_smooth = rolling_test_loss.update(test_loss)
2197
+
2198
+ # pgpe update of the plan
2199
+ pgpe_improve = False
2200
+ if self.use_pgpe:
2201
+ key, subkey = random.split(key)
2202
+ pgpe_params, r_max, pgpe_opt_state, pgpe_param, pgpe_converged = \
2203
+ self.pgpe.update(subkey, pgpe_params, r_max, policy_hyperparams,
2204
+ test_subs, model_params, pgpe_opt_state)
2205
+ pgpe_loss, _ = self.test_loss(
2206
+ subkey, pgpe_param, policy_hyperparams, test_subs, model_params_test)
2207
+ pgpe_loss_smooth = rolling_pgpe_loss.update(pgpe_loss)
2208
+ pgpe_return = -pgpe_loss_smooth
2209
+
2210
+ # replace with the PGPE solution if it reaches a new minimum or the train loss is invalid
2211
+ if pgpe_loss_smooth < best_loss or not np.isfinite(train_loss):
2212
+ policy_params = pgpe_param
2213
+ test_loss, test_loss_smooth = pgpe_loss, pgpe_loss_smooth
2214
+ converged = pgpe_converged
2215
+ pgpe_improve = True
2216
+ total_pgpe_it += 1
2217
+ else:
2218
+ pgpe_loss, pgpe_loss_smooth, pgpe_return = None, None, None
2219
+
2220
+ # evaluate test losses and record best plan so far
2221
+ if test_loss_smooth < best_loss:
2222
+ best_params, best_loss, best_grad = \
2223
+ policy_params, test_loss_smooth, train_log['grad']
2224
+ last_iter_improve = it
2225
+
1834
2226
  # ==================================================================
1835
2227
  # STATUS CHECKS AND LOGGING
1836
2228
  # ==================================================================
1837
2229
 
1838
2230
  # no progress
1839
- grad_norm_zero, _ = jax.tree_util.tree_flatten(
1840
- jax.tree_map(is_all_zero_fn, train_log['grad']))
1841
- if np.all(grad_norm_zero):
2231
+ if (not pgpe_improve) and self.check_zero_grad(train_log['grad']):
1842
2232
  status = JaxPlannerStatus.NO_PROGRESS
1843
2233
 
1844
2234
  # constraint satisfaction problem
@@ -1850,21 +2240,14 @@ r"""
1850
2240
  status = JaxPlannerStatus.PRECONDITION_POSSIBLY_UNSATISFIED
1851
2241
 
1852
2242
  # numerical error
1853
- if not np.isfinite(train_loss):
1854
- raise_warning(
1855
- f'JAX planner aborted due to invalid loss {train_loss}.', 'red')
2243
+ if self.use_pgpe:
2244
+ invalid_loss = not (np.isfinite(train_loss) or np.isfinite(pgpe_loss))
2245
+ else:
2246
+ invalid_loss = not np.isfinite(train_loss)
2247
+ if invalid_loss:
2248
+ raise_warning(f'Planner aborted due to invalid loss {train_loss}.', 'red')
1856
2249
  status = JaxPlannerStatus.INVALID_GRADIENT
1857
2250
 
1858
- # evaluate test losses and record best plan so far
1859
- test_loss, (log, model_params_test) = self.test_loss(
1860
- subkey, policy_params, policy_hyperparams,
1861
- test_subs, model_params_test)
1862
- test_loss = rolling_test_loss.update(test_loss)
1863
- if test_loss < best_loss:
1864
- best_params, best_loss, best_grad = \
1865
- policy_params, test_loss, train_log['grad']
1866
- last_iter_improve = it
1867
-
1868
2251
  # reached computation budget
1869
2252
  elapsed = time.time() - start_time - elapsed_outside_loop
1870
2253
  if elapsed >= train_seconds:
@@ -1878,11 +2261,14 @@ r"""
1878
2261
  'status': status,
1879
2262
  'iteration': it,
1880
2263
  'train_return':-train_loss,
1881
- 'test_return':-test_loss,
2264
+ 'test_return':-test_loss_smooth,
1882
2265
  'best_return':-best_loss,
2266
+ 'pgpe_return': pgpe_return,
1883
2267
  'params': policy_params,
1884
2268
  'best_params': best_params,
2269
+ 'pgpe_params': pgpe_params,
1885
2270
  'last_iteration_improved': last_iter_improve,
2271
+ 'pgpe_improved': pgpe_improve,
1886
2272
  'grad': train_log['grad'],
1887
2273
  'best_grad': best_grad,
1888
2274
  'updates': train_log['updates'],
@@ -1891,7 +2277,7 @@ r"""
1891
2277
  'model_params': model_params,
1892
2278
  'progress': progress_percent,
1893
2279
  'train_log': train_log,
1894
- **log
2280
+ **test_log
1895
2281
  }
1896
2282
 
1897
2283
  # stopping condition reached
@@ -1902,9 +2288,9 @@ r"""
1902
2288
  if print_progress:
1903
2289
  iters.n = progress_percent
1904
2290
  iters.set_description(
1905
- f'{position_str} {it:6} it / {-train_loss:14.6f} train / '
1906
- f'{-test_loss:14.6f} test / {-best_loss:14.6f} best / '
1907
- f'{status.value} status'
2291
+ f'{position_str} {it:6} it / {-train_loss:14.5f} train / '
2292
+ f'{-test_loss_smooth:14.5f} test / {-best_loss:14.5f} best / '
2293
+ f'{status.value} status / {total_pgpe_it:6} pgpe'
1908
2294
  )
1909
2295
 
1910
2296
  # dash-board
@@ -1923,7 +2309,7 @@ r"""
1923
2309
  # ======================================================================
1924
2310
  # POST-PROCESSING AND CLEANUP
1925
2311
  # ======================================================================
1926
-
2312
+
1927
2313
  # release resources
1928
2314
  if print_progress:
1929
2315
  iters.close()
@@ -1935,7 +2321,7 @@ r"""
1935
2321
  messages.update(JaxRDDLCompiler.get_error_messages(error_code))
1936
2322
  if messages:
1937
2323
  messages = '\n'.join(messages)
1938
- raise_warning('The JAX compiler encountered the following '
2324
+ raise_warning('JAX compiler encountered the following '
1939
2325
  'error(s) in the original RDDL formulation '
1940
2326
  f'during test evaluation:\n{messages}', 'red')
1941
2327
 
@@ -1943,14 +2329,14 @@ r"""
1943
2329
  if print_summary:
1944
2330
  grad_norm = jax.tree_map(lambda x: np.linalg.norm(x).item(), best_grad)
1945
2331
  diagnosis = self._perform_diagnosis(
1946
- last_iter_improve, -train_loss, -test_loss, -best_loss, grad_norm)
2332
+ last_iter_improve, -train_loss, -test_loss_smooth, -best_loss, grad_norm)
1947
2333
  print(f'summary of optimization:\n'
1948
- f' status_code ={status}\n'
1949
- f' time_elapsed ={elapsed}\n'
2334
+ f' status ={status}\n'
2335
+ f' time ={elapsed:.6f} sec.\n'
1950
2336
  f' iterations ={it}\n'
1951
- f' best_objective={-best_loss}\n'
1952
- f' best_grad_norm={grad_norm}\n'
1953
- f' diagnosis: {diagnosis}\n')
2337
+ f' best objective={-best_loss:.6f}\n'
2338
+ f' best grad norm={grad_norm}\n'
2339
+ f'diagnosis: {diagnosis}\n')
1954
2340
 
1955
2341
  def _perform_diagnosis(self, last_iter_improve,
1956
2342
  train_return, test_return, best_return, grad_norm):
@@ -1970,23 +2356,24 @@ r"""
1970
2356
  if last_iter_improve <= 1:
1971
2357
  if grad_is_zero:
1972
2358
  return termcolor.colored(
1973
- '[FAILURE] no progress was made, '
1974
- f'and max grad norm {max_grad_norm:.6f} is zero: '
1975
- 'solver likely stuck in a plateau.', 'red')
2359
+ '[FAILURE] no progress was made '
2360
+ f'and max grad norm {max_grad_norm:.6f} was zero: '
2361
+ 'the solver was likely stuck in a plateau.', 'red')
1976
2362
  else:
1977
2363
  return termcolor.colored(
1978
- '[FAILURE] no progress was made, '
1979
- f'but max grad norm {max_grad_norm:.6f} is non-zero: '
1980
- 'likely poor learning rate or other hyper-parameter.', 'red')
2364
+ '[FAILURE] no progress was made '
2365
+ f'but max grad norm {max_grad_norm:.6f} was non-zero: '
2366
+ 'the learning rate or other hyper-parameters were likely suboptimal.',
2367
+ 'red')
1981
2368
 
1982
2369
  # model is likely poor IF:
1983
2370
  # 1. the train and test return disagree
1984
2371
  if not (validation_error < 20):
1985
2372
  return termcolor.colored(
1986
- '[WARNING] progress was made, '
1987
- f'but relative train-test error {validation_error:.6f} is high: '
1988
- 'likely poor model relaxation around the solution, '
1989
- 'or the batch size is too small.', 'yellow')
2373
+ '[WARNING] progress was made '
2374
+ f'but relative train-test error {validation_error:.6f} was high: '
2375
+ 'model relaxation around the solution was poor '
2376
+ 'or the batch size was too small.', 'yellow')
1990
2377
 
1991
2378
  # model likely did not converge IF:
1992
2379
  # 1. the max grad relative to the return is high
@@ -1994,15 +2381,15 @@ r"""
1994
2381
  return_to_grad_norm = abs(best_return) / max_grad_norm
1995
2382
  if not (return_to_grad_norm > 1):
1996
2383
  return termcolor.colored(
1997
- '[WARNING] progress was made, '
1998
- f'but max grad norm {max_grad_norm:.6f} is high: '
1999
- 'likely the solution is not locally optimal, '
2000
- 'or the relaxed model is not smooth around the solution, '
2001
- 'or the batch size is too small.', 'yellow')
2384
+ '[WARNING] progress was made '
2385
+ f'but max grad norm {max_grad_norm:.6f} was high: '
2386
+ 'the solution was likely locally suboptimal, '
2387
+ 'or the relaxed model was not smooth around the solution, '
2388
+ 'or the batch size was too small.', 'yellow')
2002
2389
 
2003
2390
  # likely successful
2004
2391
  return termcolor.colored(
2005
- '[SUCCESS] planner has converged successfully '
2392
+ '[SUCCESS] solver converged successfully '
2006
2393
  '(note: not all potential problems can be ruled out).', 'green')
2007
2394
 
2008
2395
  def get_action(self, key: random.PRNGKey,
@@ -2035,8 +2422,8 @@ r"""
2035
2422
  # must be numeric array
2036
2423
  # exception is for POMDPs at 1st epoch when observ-fluents are None
2037
2424
  dtype = np.atleast_1d(values).dtype
2038
- if not jnp.issubdtype(dtype, jnp.number) \
2039
- and not jnp.issubdtype(dtype, jnp.bool_):
2425
+ if not np.issubdtype(dtype, np.number) \
2426
+ and not np.issubdtype(dtype, np.bool_):
2040
2427
  if step == 0 and var in self.rddl.observ_fluents:
2041
2428
  subs[var] = self.test_compiled.init_values[var]
2042
2429
  else:
@@ -2077,10 +2464,11 @@ def mean_variance_utility(returns: jnp.ndarray, beta: float) -> float:
2077
2464
 
2078
2465
  @jax.jit
2079
2466
  def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
2080
- alpha_mask = jax.lax.stop_gradient(
2081
- returns <= jnp.percentile(returns, q=100 * alpha))
2082
- return jnp.sum(returns * alpha_mask) / jnp.sum(alpha_mask)
2083
-
2467
+ var = jnp.percentile(returns, q=100 * alpha)
2468
+ mask = returns <= var
2469
+ weights = mask / jnp.maximum(1, jnp.sum(mask))
2470
+ return jnp.sum(returns * weights)
2471
+
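The reworked CVaR utility above averages the worst alpha-fraction of sampled returns and adds a guard against division by zero in place of the old stop-gradient trick. A small worked example with hypothetical return samples:

    import jax.numpy as jnp

    returns = jnp.arange(1.0, 11.0)             # hypothetical sampled returns 1..10
    var = jnp.percentile(returns, q=100 * 0.3)  # 30th percentile = 3.7
    mask = returns <= var                       # selects returns 1, 2 and 3
    weights = mask / jnp.maximum(1, jnp.sum(mask))
    cvar = jnp.sum(returns * weights)           # (1 + 2 + 3) / 3 = 2.0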
2084
2472
 
2085
2473
  # ***********************************************************************
2086
2474
  # ALL VERSIONS OF CONTROLLERS