pyRDDLGym-jax 1.3.tar.gz → 2.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/PKG-INFO +1 -1
  2. pyrddlgym_jax-2.0/pyRDDLGym_jax/__init__.py +1 -0
  3. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/core/compiler.py +16 -1
  4. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/core/logic.py +36 -9
  5. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/core/planner.py +445 -90
  6. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/core/simulator.py +20 -0
  7. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/core/tuning.py +15 -0
  8. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/core/visualization.py +48 -0
  9. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +3 -3
  10. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +4 -4
  11. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +1 -0
  12. pyrddlgym_jax-2.0/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +19 -0
  13. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +1 -0
  14. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +1 -0
  15. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +1 -0
  16. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +1 -0
  17. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +1 -0
  18. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +1 -0
  19. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax.egg-info/PKG-INFO +1 -1
  20. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/setup.py +1 -1
  21. pyrddlgym_jax-1.3/pyRDDLGym_jax/__init__.py +0 -1
  22. pyrddlgym_jax-1.3/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +0 -18
  23. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/LICENSE +0 -0
  24. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/README.md +0 -0
  25. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/core/__init__.py +0 -0
  26. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/core/assets/__init__.py +0 -0
  27. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/core/assets/favicon.ico +0 -0
  28. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/entry_point.py +0 -0
  29. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/__init__.py +0 -0
  30. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +0 -0
  31. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +0 -0
  32. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +0 -0
  33. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +0 -0
  34. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +0 -0
  35. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +0 -0
  36. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +0 -0
  37. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +0 -0
  38. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +0 -0
  39. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
  40. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/configs/default_drp.cfg +0 -0
  41. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/configs/default_replan.cfg +0 -0
  42. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/configs/default_slp.cfg +0 -0
  43. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/configs/tuning_drp.cfg +0 -0
  44. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/configs/tuning_replan.cfg +0 -0
  45. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/configs/tuning_slp.cfg +0 -0
  46. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/run_gradient.py +0 -0
  47. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/run_gym.py +0 -0
  48. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/run_plan.py +0 -0
  49. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/run_scipy.py +0 -0
  50. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax/examples/run_tune.py +0 -0
  51. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax.egg-info/SOURCES.txt +0 -0
  52. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax.egg-info/dependency_links.txt +0 -0
  53. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax.egg-info/entry_points.txt +0 -0
  54. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax.egg-info/requires.txt +0 -0
  55. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/pyRDDLGym_jax.egg-info/top_level.txt +0 -0
  56. {pyrddlgym_jax-1.3 → pyrddlgym_jax-2.0}/setup.cfg +0 -0
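The headline change in 2.0 is a parameter-exploring policy gradient (PGPE) fallback that runs alongside the gradient-based planner: `core/planner.py` gains `PGPE`/`GaussianPGPE` classes and a `pgpe` argument on `JaxBackpropPlanner` (defaulting to `GaussianPGPE()`), the config loader learns `pgpe`/`pgpe_kwargs` keys, the dashboard plots a PGPE return curve, and most example configs opt out with `pgpe=None`. A minimal construction sketch, using only keyword names that appear verbatim in the diff below (the import path is assumed from the package layout):

```python
# Sketch only: the keyword names come from the GaussianPGPE constructor added
# in this release; the import path is assumed from the file layout.
import optax
from pyRDDLGym_jax.core.planner import GaussianPGPE

pgpe = GaussianPGPE(
    batch_size=1,                  # parameter samples drawn per optimization step
    init_sigma=1.0,                # initial std. dev. of the Gaussian search distribution
    sigma_range=(1e-5, 1e5),       # clip range applied to sigma after each update
    super_symmetric=True,          # super-symmetric sampling (Sehnke and Zhao)
    optimizer=optax.adam,          # optax factory shared by the mu and sigma optimizers
    optimizer_kwargs_mu={'learning_rate': 0.1},
    optimizer_kwargs_sigma={'learning_rate': 0.1},
)
# Passing pgpe=None to JaxBackpropPlanner, or pgpe=None in a config's [Optimizer]
# section, disables the fallback and restores the 1.x behaviour.
```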
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: pyRDDLGym-jax
3
- Version: 1.3
3
+ Version: 2.0
4
4
  Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
5
5
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
6
6
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -0,0 +1 @@
1
+ __version__ = '2.0'
@@ -1,3 +1,18 @@
1
+ # ***********************************************************************
2
+ # JAXPLAN
3
+ #
4
+ # Author: Michael Gimelfarb
5
+ #
6
+ # REFERENCES:
7
+ #
8
+ # [1] Gimelfarb, Michael, Ayal Taitler, and Scott Sanner. "JaxPlan and GurobiPlan:
9
+ # Optimization Baselines for Replanning in Discrete and Mixed Discrete-Continuous
10
+ # Probabilistic Domains." Proceedings of the International Conference on Automated
11
+ # Planning and Scheduling. Vol. 34. 2024.
12
+ #
13
+ # ***********************************************************************
14
+
15
+
1
16
  from functools import partial
2
17
  import traceback
3
18
  from typing import Any, Callable, Dict, List, Optional
@@ -524,7 +539,7 @@ class JaxRDDLCompiler:
524
539
  _jax_wrapped_single_step_policy,
525
540
  in_axes=(0, None, None, None, 0, None)
526
541
  )(keys, policy_params, hyperparams, step, subs, model_params)
527
- model_params = jax.tree_map(lambda x: jnp.mean(x, axis=0), model_params)
542
+ model_params = jax.tree_map(partial(jnp.mean, axis=0), model_params)
528
543
  carry = (key, policy_params, hyperparams, subs, model_params)
529
544
  return carry, log
530
545
 
@@ -1,4 +1,31 @@
1
- from typing import Optional, Set
1
+ # ***********************************************************************
2
+ # JAXPLAN
3
+ #
4
+ # Author: Michael Gimelfarb
5
+ #
6
+ # REFERENCES:
7
+ #
8
+ # [1] Gimelfarb, Michael, Ayal Taitler, and Scott Sanner. "JaxPlan and GurobiPlan:
9
+ # Optimization Baselines for Replanning in Discrete and Mixed Discrete-Continuous
10
+ # Probabilistic Domains." Proceedings of the International Conference on Automated
11
+ # Planning and Scheduling. Vol. 34. 2024.
12
+ #
13
+ # [2] Petersen, Felix, Christian Borgelt, Hilde Kuehne, and Oliver Deussen. "Learning with
14
+ # algorithmic supervision via continuous relaxations." Advances in Neural Information
15
+ # Processing Systems 34 (2021): 16520-16531.
16
+ #
17
+ # [3] Agustsson, Eirikur, and Lucas Theis. "Universally quantized neural compression."
18
+ # Advances in neural information processing systems 33 (2020): 12367-12376.
19
+ #
20
+ # [4] Gupta, Madan M., and J. Qi. "Theory of T-norms and fuzzy inference
21
+ # methods." Fuzzy sets and systems 40, no. 3 (1991): 431-450.
22
+ #
23
+ # [5] Jang, Eric, Shixiang Gu, and Ben Poole. "Categorical Reparameterization with
24
+ # Gumbel-Softmax." In International Conference on Learning Representations (ICLR 2017).
25
+ # OpenReview.net, 2017.
26
+ #
27
+ # ***********************************************************************
28
+
2
29
 
3
30
  import jax
4
31
  import jax.numpy as jnp
@@ -759,14 +786,14 @@ class FuzzyLogic(Logic):
759
786
 
760
787
  def __str__(self) -> str:
761
788
  return (f'model relaxation:\n'
762
- f' tnorm ={str(self.tnorm)}\n'
763
- f' complement ={str(self.complement)}\n'
764
- f' comparison ={str(self.comparison)}\n'
765
- f' sampling ={str(self.sampling)}\n'
766
- f' rounding ={str(self.rounding)}\n'
767
- f' control ={str(self.control)}\n'
768
- f' underflow_tol ={self.eps}\n'
769
- f' use_64_bit ={self.use64bit}')
789
+ f' tnorm ={str(self.tnorm)}\n'
790
+ f' complement ={str(self.complement)}\n'
791
+ f' comparison ={str(self.comparison)}\n'
792
+ f' sampling ={str(self.sampling)}\n'
793
+ f' rounding ={str(self.rounding)}\n'
794
+ f' control ={str(self.control)}\n'
795
+ f' underflow_tol={self.eps}\n'
796
+ f' use_64_bit ={self.use64bit}\n')
770
797
 
771
798
  def summarize_hyperparameters(self) -> None:
772
799
  print(self.__str__())
@@ -1,12 +1,43 @@
1
+ # ***********************************************************************
2
+ # JAXPLAN
3
+ #
4
+ # Author: Michael Gimelfarb
5
+ #
6
+ # RELEVANT SOURCES:
7
+ #
8
+ # [1] Gimelfarb, Michael, Ayal Taitler, and Scott Sanner. "JaxPlan and GurobiPlan:
9
+ # Optimization Baselines for Replanning in Discrete and Mixed Discrete-Continuous
10
+ # Probabilistic Domains." Proceedings of the International Conference on Automated
11
+ # Planning and Scheduling. Vol. 34. 2024.
12
+ #
13
+ # [2] Patton, Noah, Jihwan Jeong, Mike Gimelfarb, and Scott Sanner. "A Distributional
14
+ # Framework for Risk-Sensitive End-to-End Planning in Continuous MDPs." In Proceedings of
15
+ # the AAAI Conference on Artificial Intelligence, vol. 36, no. 9, pp. 9894-9901. 2022.
16
+ #
17
+ # [3] Bueno, Thiago P., Leliane N. de Barros, Denis D. Mauá, and Scott Sanner. "Deep
18
+ # reactive policies for planning in stochastic nonlinear domains." In Proceedings of the
19
+ # AAAI Conference on Artificial Intelligence, vol. 33, no. 01, pp. 7530-7537. 2019.
20
+ #
21
+ # [4] Wu, Ga, Buser Say, and Scott Sanner. "Scalable planning with tensorflow for hybrid
22
+ # nonlinear domains." Advances in Neural Information Processing Systems 30 (2017).
23
+ #
24
+ # [5] Sehnke, Frank, and Tingting Zhao. "Baseline-free sampling in parameter exploring
25
+ # policy gradients: Super symmetric pgpe." Artificial Neural Networks: Methods and
26
+ # Applications in Bio-/Neuroinformatics. Springer International Publishing, 2015.
27
+ #
28
+ # ***********************************************************************
29
+
30
+
1
31
  from ast import literal_eval
2
32
  from collections import deque
3
33
  import configparser
4
34
  from enum import Enum
35
+ from functools import partial
5
36
  import os
6
37
  import sys
7
38
  import time
8
39
  import traceback
9
- from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Tuple, Union
40
+ from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Type, Tuple, Union
10
41
 
11
42
  import haiku as hk
12
43
  import jax
@@ -163,7 +194,20 @@ def _load_config(config, args):
163
194
  del planner_args['optimizer']
164
195
  else:
165
196
  planner_args['optimizer'] = optimizer
166
-
197
+
198
+ # pgpe optimizer
199
+ pgpe_method = planner_args.get('pgpe', 'GaussianPGPE')
200
+ pgpe_kwargs = planner_args.pop('pgpe_kwargs', {})
201
+ if pgpe_method is not None:
202
+ if 'optimizer' in pgpe_kwargs:
203
+ pgpe_optimizer = _getattr_any(packages=[optax], item=pgpe_kwargs['optimizer'])
204
+ if pgpe_optimizer is None:
205
+ raise_warning(f'Ignoring invalid optimizer <{pgpe_optimizer}>.', 'red')
206
+ del pgpe_kwargs['optimizer']
207
+ else:
208
+ pgpe_kwargs['optimizer'] = pgpe_optimizer
209
+ planner_args['pgpe'] = getattr(sys.modules[__name__], pgpe_method)(**pgpe_kwargs)
210
+
167
211
  # optimize call RNG key
168
212
  planner_key = train_args.get('key', None)
169
213
  if planner_key is not None:
@@ -469,16 +513,16 @@ class JaxStraightLinePlan(JaxPlan):
469
513
  bounds = '\n '.join(
470
514
  map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
471
515
  return (f'policy hyper-parameters:\n'
472
- f' initializer ={self._initializer_base}\n'
473
- f'constraint-sat strategy (simple):\n'
474
- f' parsed_action_bounds =\n {bounds}\n'
475
- f' wrap_sigmoid ={self._wrap_sigmoid}\n'
476
- f' wrap_sigmoid_min_prob={self._min_action_prob}\n'
477
- f' wrap_non_bool ={self._wrap_non_bool}\n'
478
- f'constraint-sat strategy (complex):\n'
479
- f' wrap_softmax ={self._wrap_softmax}\n'
480
- f' use_new_projection ={self._use_new_projection}\n'
481
- f' max_projection_iters ={self._max_constraint_iter}')
516
+ f' initializer={self._initializer_base}\n'
517
+ f' constraint-sat strategy (simple):\n'
518
+ f' parsed_action_bounds =\n {bounds}\n'
519
+ f' wrap_sigmoid ={self._wrap_sigmoid}\n'
520
+ f' wrap_sigmoid_min_prob={self._min_action_prob}\n'
521
+ f' wrap_non_bool ={self._wrap_non_bool}\n'
522
+ f' constraint-sat strategy (complex):\n'
523
+ f' wrap_softmax ={self._wrap_softmax}\n'
524
+ f' use_new_projection ={self._use_new_projection}\n'
525
+ f' max_projection_iters={self._max_constraint_iter}\n')
482
526
 
483
527
  def compile(self, compiled: JaxRDDLCompilerWithGrad,
484
528
  _bounds: Bounds,
@@ -856,15 +900,16 @@ class JaxDeepReactivePolicy(JaxPlan):
856
900
  bounds = '\n '.join(
857
901
  map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
858
902
  return (f'policy hyper-parameters:\n'
859
- f' topology ={self._topology}\n'
860
- f' activation_fn ={self._activations[0].__name__}\n'
861
- f' initializer ={type(self._initializer_base).__name__}\n'
862
- f' apply_input_norm ={self._normalize}\n'
863
- f' input_norm_layerwise={self._normalize_per_layer}\n'
864
- f' input_norm_args ={self._normalizer_kwargs}\n'
865
- f'constraint-sat strategy:\n'
866
- f' parsed_action_bounds=\n {bounds}\n'
867
- f' wrap_non_bool ={self._wrap_non_bool}')
903
+ f' topology ={self._topology}\n'
904
+ f' activation_fn={self._activations[0].__name__}\n'
905
+ f' initializer ={type(self._initializer_base).__name__}\n'
906
+ f' input norm:\n'
907
+ f' apply_input_norm ={self._normalize}\n'
908
+ f' input_norm_layerwise={self._normalize_per_layer}\n'
909
+ f' input_norm_args ={self._normalizer_kwargs}\n'
910
+ f' constraint-sat strategy:\n'
911
+ f' parsed_action_bounds=\n {bounds}\n'
912
+ f' wrap_non_bool ={self._wrap_non_bool}\n')
868
913
 
869
914
  def compile(self, compiled: JaxRDDLCompilerWithGrad,
870
915
  _bounds: Bounds,
@@ -1090,10 +1135,11 @@ class JaxDeepReactivePolicy(JaxPlan):
1090
1135
 
1091
1136
 
1092
1137
  # ***********************************************************************
1093
- # ALL VERSIONS OF JAX PLANNER
1138
+ # SUPPORTING FUNCTIONS
1094
1139
  #
1095
- # - simple gradient descent based planner
1096
- # - more stable but slower line search based planner
1140
+ # - smoothed mean calculation
1141
+ # - planner status
1142
+ # - stopping criteria
1097
1143
  #
1098
1144
  # ***********************************************************************
1099
1145
 
@@ -1167,6 +1213,264 @@ class NoImprovementStoppingRule(JaxPlannerStoppingRule):
1167
1213
  return f'No improvement for {self.patience} iterations'
1168
1214
 
1169
1215
 
1216
+ # ***********************************************************************
1217
+ # PARAMETER EXPLORING POLICY GRADIENTS (PGPE)
1218
+ #
1219
+ # - simple Gaussian PGPE
1220
+ #
1221
+ # ***********************************************************************
1222
+
1223
+
1224
+ class PGPE:
1225
+ """Base class for all PGPE strategies."""
1226
+
1227
+ def __init__(self) -> None:
1228
+ self._initializer = None
1229
+ self._update = None
1230
+
1231
+ @property
1232
+ def initialize(self):
1233
+ return self._initializer
1234
+
1235
+ @property
1236
+ def update(self):
1237
+ return self._update
1238
+
1239
+ def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type) -> None:
1240
+ raise NotImplementedError
1241
+
1242
+
1243
+ class GaussianPGPE(PGPE):
1244
+ '''PGPE with a Gaussian parameter distribution.'''
1245
+
1246
+ def __init__(self, batch_size: int=1,
1247
+ init_sigma: float=1.0,
1248
+ sigma_range: Tuple[float, float]=(1e-5, 1e5),
1249
+ scale_reward: bool=True,
1250
+ super_symmetric: bool=True,
1251
+ super_symmetric_accurate: bool=True,
1252
+ optimizer: Callable[..., optax.GradientTransformation]=optax.adam,
1253
+ optimizer_kwargs_mu: Optional[Kwargs]=None,
1254
+ optimizer_kwargs_sigma: Optional[Kwargs]=None) -> None:
1255
+ '''Creates a new Gaussian PGPE planner.
1256
+
1257
+ :param batch_size: how many policy parameters to sample per optimization step
1258
+ :param init_sigma: initial standard deviation of Gaussian
1259
+ :param sigma_range: bounds to constrain standard deviation
1260
+ :param scale_reward: whether to apply reward scaling as in the paper
1261
+ :param super_symmetric: whether to use super-symmetric sampling as in the paper
1262
+ :param super_symmetric_accurate: whether to use the accurate formula for super-
1263
+ symmetric sampling or the simplified but biased formula
1264
+ :param optimizer: a factory for an optax SGD algorithm
1265
+ :param optimizer_kwargs_mu: a dictionary of parameters to pass to the SGD
1266
+ factory for the mean optimizer
1267
+ :param optimizer_kwargs_sigma: a dictionary of parameters to pass to the SGD
1268
+ factory for the standard deviation optimizer
1269
+ '''
1270
+ super().__init__()
1271
+
1272
+ self.batch_size = batch_size
1273
+ self.init_sigma = init_sigma
1274
+ self.sigma_range = sigma_range
1275
+ self.scale_reward = scale_reward
1276
+ self.super_symmetric = super_symmetric
1277
+ self.super_symmetric_accurate = super_symmetric_accurate
1278
+
1279
+ # set optimizers
1280
+ if optimizer_kwargs_mu is None:
1281
+ optimizer_kwargs_mu = {'learning_rate': 0.1}
1282
+ self.optimizer_kwargs_mu = optimizer_kwargs_mu
1283
+ if optimizer_kwargs_sigma is None:
1284
+ optimizer_kwargs_sigma = {'learning_rate': 0.1}
1285
+ self.optimizer_kwargs_sigma = optimizer_kwargs_sigma
1286
+ self.optimizer_name = optimizer
1287
+ mu_optimizer = optimizer(**optimizer_kwargs_mu)
1288
+ sigma_optimizer = optimizer(**optimizer_kwargs_sigma)
1289
+ self.optimizers = (mu_optimizer, sigma_optimizer)
1290
+
1291
+ def __str__(self) -> str:
1292
+ return (f'PGPE hyper-parameters:\n'
1293
+ f' method ={self.__class__.__name__}\n'
1294
+ f' batch_size ={self.batch_size}\n'
1295
+ f' init_sigma ={self.init_sigma}\n'
1296
+ f' sigma_range ={self.sigma_range}\n'
1297
+ f' scale_reward ={self.scale_reward}\n'
1298
+ f' super_symmetric={self.super_symmetric}\n'
1299
+ f' accurate ={self.super_symmetric_accurate}\n'
1300
+ f' optimizer ={self.optimizer_name}\n'
1301
+ f' optimizer_kwargs:\n'
1302
+ f' mu ={self.optimizer_kwargs_mu}\n'
1303
+ f' sigma={self.optimizer_kwargs_sigma}\n'
1304
+ )
1305
+
1306
+ def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type) -> None:
1307
+ MIN_NORM = 1e-5
1308
+ sigma0 = self.init_sigma
1309
+ sigma_range = self.sigma_range
1310
+ scale_reward = self.scale_reward
1311
+ super_symmetric = self.super_symmetric
1312
+ super_symmetric_accurate = self.super_symmetric_accurate
1313
+ batch_size = self.batch_size
1314
+ optimizers = (mu_optimizer, sigma_optimizer) = self.optimizers
1315
+
1316
+ # initializer
1317
+ def _jax_wrapped_pgpe_init(key, policy_params):
1318
+ mu = policy_params
1319
+ sigma = jax.tree_map(lambda x: sigma0 * jnp.ones_like(x), mu)
1320
+ pgpe_params = (mu, sigma)
1321
+ pgpe_opt_state = tuple(opt.init(param)
1322
+ for (opt, param) in zip(optimizers, pgpe_params))
1323
+ return pgpe_params, pgpe_opt_state
1324
+
1325
+ self._initializer = jax.jit(_jax_wrapped_pgpe_init)
1326
+
1327
+ # parameter sampling functions
1328
+ def _jax_wrapped_mu_noise(key, sigma):
1329
+ return sigma * random.normal(key, shape=jnp.shape(sigma), dtype=real_dtype)
1330
+
1331
+ def _jax_wrapped_epsilon_star(sigma, epsilon):
1332
+ c1, c2, c3 = -0.06655, -0.9706, 0.124
1333
+ phi = 0.67449 * sigma
1334
+ a = (sigma - jnp.abs(epsilon)) / sigma
1335
+ if super_symmetric_accurate:
1336
+ aa = jnp.abs(a)
1337
+ epsilon_star = jnp.sign(epsilon) * phi * jnp.where(
1338
+ a <= 0,
1339
+ jnp.exp(c1 * aa * (aa * aa - 1) / jnp.log(aa + 1e-10) + c2 * aa),
1340
+ jnp.exp(aa - c3 * aa * jnp.log(1.0 - jnp.power(aa, 3) + 1e-10))
1341
+ )
1342
+ else:
1343
+ epsilon_star = jnp.sign(epsilon) * phi * jnp.exp(a)
1344
+ return epsilon_star
1345
+
1346
+ def _jax_wrapped_sample_params(key, mu, sigma):
1347
+ keys = random.split(key, num=len(jax.tree_util.tree_leaves(mu)))
1348
+ keys_pytree = jax.tree_util.tree_unflatten(
1349
+ treedef=jax.tree_util.tree_structure(mu), leaves=keys)
1350
+ epsilon = jax.tree_map(_jax_wrapped_mu_noise, keys_pytree, sigma)
1351
+ p1 = jax.tree_map(jnp.add, mu, epsilon)
1352
+ p2 = jax.tree_map(jnp.subtract, mu, epsilon)
1353
+ if super_symmetric:
1354
+ epsilon_star = jax.tree_map(_jax_wrapped_epsilon_star, sigma, epsilon)
1355
+ p3 = jax.tree_map(jnp.add, mu, epsilon_star)
1356
+ p4 = jax.tree_map(jnp.subtract, mu, epsilon_star)
1357
+ else:
1358
+ epsilon_star, p3, p4 = epsilon, p1, p2
1359
+ return (p1, p2, p3, p4), (epsilon, epsilon_star)
1360
+
1361
+ # policy gradient update functions
1362
+ def _jax_wrapped_mu_grad(epsilon, epsilon_star, r1, r2, r3, r4, m):
1363
+ if super_symmetric:
1364
+ if scale_reward:
1365
+ scale1 = jnp.maximum(MIN_NORM, m - (r1 + r2) / 2)
1366
+ scale2 = jnp.maximum(MIN_NORM, m - (r3 + r4) / 2)
1367
+ else:
1368
+ scale1 = scale2 = 1.0
1369
+ r_mu1 = (r1 - r2) / (2 * scale1)
1370
+ r_mu2 = (r3 - r4) / (2 * scale2)
1371
+ grad = -(r_mu1 * epsilon + r_mu2 * epsilon_star)
1372
+ else:
1373
+ if scale_reward:
1374
+ scale = jnp.maximum(MIN_NORM, m - (r1 + r2) / 2)
1375
+ else:
1376
+ scale = 1.0
1377
+ r_mu = (r1 - r2) / (2 * scale)
1378
+ grad = -r_mu * epsilon
1379
+ return grad
1380
+
1381
+ def _jax_wrapped_sigma_grad(epsilon, epsilon_star, sigma, r1, r2, r3, r4, m):
1382
+ if super_symmetric:
1383
+ mask = r1 + r2 >= r3 + r4
1384
+ epsilon_tau = mask * epsilon + (1 - mask) * epsilon_star
1385
+ s = epsilon_tau * epsilon_tau / sigma - sigma
1386
+ if scale_reward:
1387
+ scale = jnp.maximum(MIN_NORM, m - (r1 + r2 + r3 + r4) / 4)
1388
+ else:
1389
+ scale = 1.0
1390
+ r_sigma = ((r1 + r2) - (r3 + r4)) / (4 * scale)
1391
+ else:
1392
+ s = epsilon * epsilon / sigma - sigma
1393
+ if scale_reward:
1394
+ scale = jnp.maximum(MIN_NORM, jnp.abs(m))
1395
+ else:
1396
+ scale = 1.0
1397
+ r_sigma = (r1 + r2) / (2 * scale)
1398
+ grad = -r_sigma * s
1399
+ return grad
1400
+
1401
+ def _jax_wrapped_pgpe_grad(key, mu, sigma, r_max,
1402
+ policy_hyperparams, subs, model_params):
1403
+ key, subkey = random.split(key)
1404
+ (p1, p2, p3, p4), (epsilon, epsilon_star) = _jax_wrapped_sample_params(
1405
+ key, mu, sigma)
1406
+ r1 = -loss_fn(subkey, p1, policy_hyperparams, subs, model_params)[0]
1407
+ r2 = -loss_fn(subkey, p2, policy_hyperparams, subs, model_params)[0]
1408
+ r_max = jnp.maximum(r_max, r1)
1409
+ r_max = jnp.maximum(r_max, r2)
1410
+ if super_symmetric:
1411
+ r3 = -loss_fn(subkey, p3, policy_hyperparams, subs, model_params)[0]
1412
+ r4 = -loss_fn(subkey, p4, policy_hyperparams, subs, model_params)[0]
1413
+ r_max = jnp.maximum(r_max, r3)
1414
+ r_max = jnp.maximum(r_max, r4)
1415
+ else:
1416
+ r3, r4 = r1, r2
1417
+ grad_mu = jax.tree_map(
1418
+ partial(_jax_wrapped_mu_grad, r1=r1, r2=r2, r3=r3, r4=r4, m=r_max),
1419
+ epsilon, epsilon_star
1420
+ )
1421
+ grad_sigma = jax.tree_map(
1422
+ partial(_jax_wrapped_sigma_grad, r1=r1, r2=r2, r3=r3, r4=r4, m=r_max),
1423
+ epsilon, epsilon_star, sigma
1424
+ )
1425
+ return grad_mu, grad_sigma, r_max
1426
+
1427
+ def _jax_wrapped_pgpe_grad_batched(key, pgpe_params, r_max,
1428
+ policy_hyperparams, subs, model_params):
1429
+ mu, sigma = pgpe_params
1430
+ if batch_size == 1:
1431
+ mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad(
1432
+ key, mu, sigma, r_max, policy_hyperparams, subs, model_params)
1433
+ else:
1434
+ keys = random.split(key, num=batch_size)
1435
+ mu_grads, sigma_grads, r_maxs = jax.vmap(
1436
+ _jax_wrapped_pgpe_grad,
1437
+ in_axes=(0, None, None, None, None, None, None)
1438
+ )(keys, mu, sigma, r_max, policy_hyperparams, subs, model_params)
1439
+ mu_grad = jax.tree_map(partial(jnp.mean, axis=0), mu_grads)
1440
+ sigma_grad = jax.tree_map(partial(jnp.mean, axis=0), sigma_grads)
1441
+ new_r_max = jnp.max(r_maxs)
1442
+ return mu_grad, sigma_grad, new_r_max
1443
+
1444
+ def _jax_wrapped_pgpe_update(key, pgpe_params, r_max,
1445
+ policy_hyperparams, subs, model_params,
1446
+ pgpe_opt_state):
1447
+ mu, sigma = pgpe_params
1448
+ mu_state, sigma_state = pgpe_opt_state
1449
+ mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad_batched(
1450
+ key, pgpe_params, r_max, policy_hyperparams, subs, model_params)
1451
+ mu_updates, new_mu_state = mu_optimizer.update(mu_grad, mu_state, params=mu)
1452
+ sigma_updates, new_sigma_state = sigma_optimizer.update(
1453
+ sigma_grad, sigma_state, params=sigma)
1454
+ new_mu = optax.apply_updates(mu, mu_updates)
1455
+ new_mu, converged = projection(new_mu, policy_hyperparams)
1456
+ new_sigma = optax.apply_updates(sigma, sigma_updates)
1457
+ new_sigma = jax.tree_map(lambda x: jnp.clip(x, *sigma_range), new_sigma)
1458
+ new_pgpe_params = (new_mu, new_sigma)
1459
+ new_pgpe_opt_state = (new_mu_state, new_sigma_state)
1460
+ policy_params = new_mu
1461
+ return new_pgpe_params, new_r_max, new_pgpe_opt_state, policy_params, converged
1462
+
1463
+ self._update = jax.jit(_jax_wrapped_pgpe_update)
1464
+
1465
+
1466
+ # ***********************************************************************
1467
+ # ALL VERSIONS OF JAX PLANNER
1468
+ #
1469
+ # - simple gradient descent based planner
1470
+ #
1471
+ # ***********************************************************************
1472
+
1473
+
1170
1474
  class JaxBackpropPlanner:
1171
1475
  '''A class for optimizing an action sequence in the given RDDL MDP using
1172
1476
  gradient descent.'''
@@ -1183,6 +1487,7 @@ class JaxBackpropPlanner:
1183
1487
  clip_grad: Optional[float]=None,
1184
1488
  line_search_kwargs: Optional[Kwargs]=None,
1185
1489
  noise_kwargs: Optional[Kwargs]=None,
1490
+ pgpe: Optional[PGPE]=GaussianPGPE(),
1186
1491
  logic: Logic=FuzzyLogic(),
1187
1492
  use_symlog_reward: bool=False,
1188
1493
  utility: Union[Callable[[jnp.ndarray], float], str]='mean',
@@ -1213,6 +1518,7 @@ class JaxBackpropPlanner:
1213
1518
  :param line_search_kwargs: parameters to pass to optional line search
1214
1519
  method to scale learning rate
1215
1520
  :param noise_kwargs: parameters of optional gradient noise
1521
+ :param pgpe: optional policy gradient to run alongside the planner
1216
1522
  :param logic: a subclass of Logic for mapping exact mathematical
1217
1523
  operations to their differentiable counterparts
1218
1524
  :param use_symlog_reward: whether to use the symlog transform on the
@@ -1251,6 +1557,8 @@ class JaxBackpropPlanner:
1251
1557
  self.clip_grad = clip_grad
1252
1558
  self.line_search_kwargs = line_search_kwargs
1253
1559
  self.noise_kwargs = noise_kwargs
1560
+ self.pgpe = pgpe
1561
+ self.use_pgpe = pgpe is not None
1254
1562
 
1255
1563
  # set optimizer
1256
1564
  try:
@@ -1355,24 +1663,25 @@ r"""
1355
1663
  f' line_search_kwargs={self.line_search_kwargs}\n'
1356
1664
  f' noise_kwargs ={self.noise_kwargs}\n'
1357
1665
  f' batch_size_train ={self.batch_size_train}\n'
1358
- f' batch_size_test ={self.batch_size_test}')
1359
- result += '\n' + str(self.plan)
1360
- result += '\n' + str(self.logic)
1666
+ f' batch_size_test ={self.batch_size_test}\n')
1667
+ result += str(self.plan)
1668
+ if self.use_pgpe:
1669
+ result += str(self.pgpe)
1670
+ result += str(self.logic)
1361
1671
 
1362
1672
  # print model relaxation information
1363
- if not self.compiled.model_params:
1364
- return result
1365
- result += '\n' + ('Some RDDL operations are non-differentiable '
1366
- 'and will be approximated as follows:' + '\n')
1367
- exprs_by_rddl_op, values_by_rddl_op = {}, {}
1368
- for info in self.compiled.model_parameter_info().values():
1369
- rddl_op = info['rddl_op']
1370
- exprs_by_rddl_op.setdefault(rddl_op, []).append(info['id'])
1371
- values_by_rddl_op.setdefault(rddl_op, []).append(info['init_value'])
1372
- for rddl_op in sorted(exprs_by_rddl_op.keys()):
1373
- result += (f' {rddl_op}:\n'
1374
- f' addresses ={exprs_by_rddl_op[rddl_op]}\n'
1375
- f' init_values={values_by_rddl_op[rddl_op]}\n')
1673
+ if self.compiled.model_params:
1674
+ result += ('Some RDDL operations are non-differentiable '
1675
+ 'and will be approximated as follows:' + '\n')
1676
+ exprs_by_rddl_op, values_by_rddl_op = {}, {}
1677
+ for info in self.compiled.model_parameter_info().values():
1678
+ rddl_op = info['rddl_op']
1679
+ exprs_by_rddl_op.setdefault(rddl_op, []).append(info['id'])
1680
+ values_by_rddl_op.setdefault(rddl_op, []).append(info['init_value'])
1681
+ for rddl_op in sorted(exprs_by_rddl_op.keys()):
1682
+ result += (f' {rddl_op}:\n'
1683
+ f' addresses ={exprs_by_rddl_op[rddl_op]}\n'
1684
+ f' init_values={values_by_rddl_op[rddl_op]}\n')
1376
1685
  return result
1377
1686
 
1378
1687
  def summarize_hyperparameters(self) -> None:
@@ -1438,6 +1747,15 @@ r"""
1438
1747
  # optimization
1439
1748
  self.update = self._jax_update(train_loss)
1440
1749
  self.check_zero_grad = self._jax_check_zero_gradients()
1750
+
1751
+ # pgpe option
1752
+ if self.use_pgpe:
1753
+ loss_fn = self._jax_loss(rollouts=test_rollouts)
1754
+ self.pgpe.compile(
1755
+ loss_fn=loss_fn,
1756
+ projection=self.plan.projection,
1757
+ real_dtype=self.test_compiled.REAL
1758
+ )
1441
1759
 
1442
1760
  def _jax_return(self, use_symlog):
1443
1761
  gamma = self.rddl.discount
@@ -1646,7 +1964,7 @@ r"""
1646
1964
  return grad
1647
1965
 
1648
1966
  return _loss_function, _grad_function, guess_1d, jax.jit(unravel_fn)
1649
-
1967
+
1650
1968
  # ===========================================================================
1651
1969
  # OPTIMIZE API
1652
1970
  # ===========================================================================
@@ -1819,7 +2137,17 @@ r"""
1819
2137
  policy_params = guess
1820
2138
  opt_state = self.optimizer.init(policy_params)
1821
2139
  opt_aux = {}
1822
-
2140
+
2141
+ # initialize pgpe parameters
2142
+ if self.use_pgpe:
2143
+ pgpe_params, pgpe_opt_state = self.pgpe.initialize(key, policy_params)
2144
+ rolling_pgpe_loss = RollingMean(test_rolling_window)
2145
+ else:
2146
+ pgpe_params, pgpe_opt_state = None, None
2147
+ rolling_pgpe_loss = None
2148
+ total_pgpe_it = 0
2149
+ r_max = -jnp.inf
2150
+
1823
2151
  # ======================================================================
1824
2152
  # INITIALIZATION OF RUNNING STATISTICS
1825
2153
  # ======================================================================
@@ -1860,17 +2188,47 @@ r"""
1860
2188
 
1861
2189
  # update the parameters of the plan
1862
2190
  key, subkey = random.split(key)
1863
- (policy_params, converged, opt_state, opt_aux,
1864
- train_loss, train_log, model_params) = \
1865
- self.update(subkey, policy_params, policy_hyperparams,
1866
- train_subs, model_params, opt_state, opt_aux)
1867
-
2191
+ (policy_params, converged, opt_state, opt_aux, train_loss, train_log,
2192
+ model_params) = self.update(subkey, policy_params, policy_hyperparams,
2193
+ train_subs, model_params, opt_state, opt_aux)
2194
+ test_loss, (test_log, model_params_test) = self.test_loss(
2195
+ subkey, policy_params, policy_hyperparams, test_subs, model_params_test)
2196
+ test_loss_smooth = rolling_test_loss.update(test_loss)
2197
+
2198
+ # pgpe update of the plan
2199
+ pgpe_improve = False
2200
+ if self.use_pgpe:
2201
+ key, subkey = random.split(key)
2202
+ pgpe_params, r_max, pgpe_opt_state, pgpe_param, pgpe_converged = \
2203
+ self.pgpe.update(subkey, pgpe_params, r_max, policy_hyperparams,
2204
+ test_subs, model_params, pgpe_opt_state)
2205
+ pgpe_loss, _ = self.test_loss(
2206
+ subkey, pgpe_param, policy_hyperparams, test_subs, model_params_test)
2207
+ pgpe_loss_smooth = rolling_pgpe_loss.update(pgpe_loss)
2208
+ pgpe_return = -pgpe_loss_smooth
2209
+
2210
+ # replace with PGPE if it reaches a new minimum or train loss invalid
2211
+ if pgpe_loss_smooth < best_loss or not np.isfinite(train_loss):
2212
+ policy_params = pgpe_param
2213
+ test_loss, test_loss_smooth = pgpe_loss, pgpe_loss_smooth
2214
+ converged = pgpe_converged
2215
+ pgpe_improve = True
2216
+ total_pgpe_it += 1
2217
+ else:
2218
+ pgpe_loss, pgpe_loss_smooth, pgpe_return = None, None, None
2219
+
2220
+ # evaluate test losses and record best plan so far
2221
+ if test_loss_smooth < best_loss:
2222
+ best_params, best_loss, best_grad = \
2223
+ policy_params, test_loss_smooth, train_log['grad']
2224
+ last_iter_improve = it
2225
+
1868
2226
  # ==================================================================
1869
2227
  # STATUS CHECKS AND LOGGING
1870
2228
  # ==================================================================
1871
2229
 
1872
2230
  # no progress
1873
- if self.check_zero_grad(train_log['grad']):
2231
+ if (not pgpe_improve) and self.check_zero_grad(train_log['grad']):
1874
2232
  status = JaxPlannerStatus.NO_PROGRESS
1875
2233
 
1876
2234
  # constraint satisfaction problem
@@ -1882,21 +2240,14 @@ r"""
1882
2240
  status = JaxPlannerStatus.PRECONDITION_POSSIBLY_UNSATISFIED
1883
2241
 
1884
2242
  # numerical error
1885
- if not np.isfinite(train_loss):
1886
- raise_warning(
1887
- f'JAX planner aborted due to invalid loss {train_loss}.', 'red')
2243
+ if self.use_pgpe:
2244
+ invalid_loss = not (np.isfinite(train_loss) or np.isfinite(pgpe_loss))
2245
+ else:
2246
+ invalid_loss = not np.isfinite(train_loss)
2247
+ if invalid_loss:
2248
+ raise_warning(f'Planner aborted due to invalid loss {train_loss}.', 'red')
1888
2249
  status = JaxPlannerStatus.INVALID_GRADIENT
1889
2250
 
1890
- # evaluate test losses and record best plan so far
1891
- test_loss, (log, model_params_test) = self.test_loss(
1892
- subkey, policy_params, policy_hyperparams,
1893
- test_subs, model_params_test)
1894
- test_loss = rolling_test_loss.update(test_loss)
1895
- if test_loss < best_loss:
1896
- best_params, best_loss, best_grad = \
1897
- policy_params, test_loss, train_log['grad']
1898
- last_iter_improve = it
1899
-
1900
2251
  # reached computation budget
1901
2252
  elapsed = time.time() - start_time - elapsed_outside_loop
1902
2253
  if elapsed >= train_seconds:
@@ -1910,11 +2261,14 @@ r"""
1910
2261
  'status': status,
1911
2262
  'iteration': it,
1912
2263
  'train_return':-train_loss,
1913
- 'test_return':-test_loss,
2264
+ 'test_return':-test_loss_smooth,
1914
2265
  'best_return':-best_loss,
2266
+ 'pgpe_return': pgpe_return,
1915
2267
  'params': policy_params,
1916
2268
  'best_params': best_params,
2269
+ 'pgpe_params': pgpe_params,
1917
2270
  'last_iteration_improved': last_iter_improve,
2271
+ 'pgpe_improved': pgpe_improve,
1918
2272
  'grad': train_log['grad'],
1919
2273
  'best_grad': best_grad,
1920
2274
  'updates': train_log['updates'],
@@ -1923,7 +2277,7 @@ r"""
1923
2277
  'model_params': model_params,
1924
2278
  'progress': progress_percent,
1925
2279
  'train_log': train_log,
1926
- **log
2280
+ **test_log
1927
2281
  }
1928
2282
 
1929
2283
  # stopping condition reached
@@ -1934,9 +2288,9 @@ r"""
1934
2288
  if print_progress:
1935
2289
  iters.n = progress_percent
1936
2290
  iters.set_description(
1937
- f'{position_str} {it:6} it / {-train_loss:14.6f} train / '
1938
- f'{-test_loss:14.6f} test / {-best_loss:14.6f} best / '
1939
- f'{status.value} status'
2291
+ f'{position_str} {it:6} it / {-train_loss:14.5f} train / '
2292
+ f'{-test_loss_smooth:14.5f} test / {-best_loss:14.5f} best / '
2293
+ f'{status.value} status / {total_pgpe_it:6} pgpe'
1940
2294
  )
1941
2295
 
1942
2296
  # dash-board
@@ -1955,7 +2309,7 @@ r"""
1955
2309
  # ======================================================================
1956
2310
  # POST-PROCESSING AND CLEANUP
1957
2311
  # ======================================================================
1958
-
2312
+
1959
2313
  # release resources
1960
2314
  if print_progress:
1961
2315
  iters.close()
@@ -1967,7 +2321,7 @@ r"""
1967
2321
  messages.update(JaxRDDLCompiler.get_error_messages(error_code))
1968
2322
  if messages:
1969
2323
  messages = '\n'.join(messages)
1970
- raise_warning('The JAX compiler encountered the following '
2324
+ raise_warning('JAX compiler encountered the following '
1971
2325
  'error(s) in the original RDDL formulation '
1972
2326
  f'during test evaluation:\n{messages}', 'red')
1973
2327
 
@@ -1975,14 +2329,14 @@ r"""
1975
2329
  if print_summary:
1976
2330
  grad_norm = jax.tree_map(lambda x: np.linalg.norm(x).item(), best_grad)
1977
2331
  diagnosis = self._perform_diagnosis(
1978
- last_iter_improve, -train_loss, -test_loss, -best_loss, grad_norm)
2332
+ last_iter_improve, -train_loss, -test_loss_smooth, -best_loss, grad_norm)
1979
2333
  print(f'summary of optimization:\n'
1980
- f' status_code ={status}\n'
1981
- f' time_elapsed ={elapsed}\n'
2334
+ f' status ={status}\n'
2335
+ f' time ={elapsed:.6f} sec.\n'
1982
2336
  f' iterations ={it}\n'
1983
- f' best_objective={-best_loss}\n'
1984
- f' best_grad_norm={grad_norm}\n'
1985
- f' diagnosis: {diagnosis}\n')
2337
+ f' best objective={-best_loss:.6f}\n'
2338
+ f' best grad norm={grad_norm}\n'
2339
+ f'diagnosis: {diagnosis}\n')
1986
2340
 
1987
2341
  def _perform_diagnosis(self, last_iter_improve,
1988
2342
  train_return, test_return, best_return, grad_norm):
@@ -2002,23 +2356,24 @@ r"""
2002
2356
  if last_iter_improve <= 1:
2003
2357
  if grad_is_zero:
2004
2358
  return termcolor.colored(
2005
- '[FAILURE] no progress was made, '
2006
- f'and max grad norm {max_grad_norm:.6f} is zero: '
2007
- 'solver likely stuck in a plateau.', 'red')
2359
+ '[FAILURE] no progress was made '
2360
+ f'and max grad norm {max_grad_norm:.6f} was zero: '
2361
+ 'the solver was likely stuck in a plateau.', 'red')
2008
2362
  else:
2009
2363
  return termcolor.colored(
2010
- '[FAILURE] no progress was made, '
2011
- f'but max grad norm {max_grad_norm:.6f} is non-zero: '
2012
- 'likely poor learning rate or other hyper-parameter.', 'red')
2364
+ '[FAILURE] no progress was made '
2365
+ f'but max grad norm {max_grad_norm:.6f} was non-zero: '
2366
+ 'the learning rate or other hyper-parameters were likely suboptimal.',
2367
+ 'red')
2013
2368
 
2014
2369
  # model is likely poor IF:
2015
2370
  # 1. the train and test return disagree
2016
2371
  if not (validation_error < 20):
2017
2372
  return termcolor.colored(
2018
- '[WARNING] progress was made, '
2019
- f'but relative train-test error {validation_error:.6f} is high: '
2020
- 'likely poor model relaxation around the solution, '
2021
- 'or the batch size is too small.', 'yellow')
2373
+ '[WARNING] progress was made '
2374
+ f'but relative train-test error {validation_error:.6f} was high: '
2375
+ 'model relaxation around the solution was poor '
2376
+ 'or the batch size was too small.', 'yellow')
2022
2377
 
2023
2378
  # model likely did not converge IF:
2024
2379
  # 1. the max grad relative to the return is high
@@ -2026,15 +2381,15 @@ r"""
2026
2381
  return_to_grad_norm = abs(best_return) / max_grad_norm
2027
2382
  if not (return_to_grad_norm > 1):
2028
2383
  return termcolor.colored(
2029
- '[WARNING] progress was made, '
2030
- f'but max grad norm {max_grad_norm:.6f} is high: '
2031
- 'likely the solution is not locally optimal, '
2032
- 'or the relaxed model is not smooth around the solution, '
2033
- 'or the batch size is too small.', 'yellow')
2384
+ '[WARNING] progress was made '
2385
+ f'but max grad norm {max_grad_norm:.6f} was high: '
2386
+ 'the solution was likely locally suboptimal, '
2387
+ 'or the relaxed model was not smooth around the solution, '
2388
+ 'or the batch size was too small.', 'yellow')
2034
2389
 
2035
2390
  # likely successful
2036
2391
  return termcolor.colored(
2037
- '[SUCCESS] planner has converged successfully '
2392
+ '[SUCCESS] solver converged successfully '
2038
2393
  '(note: not all potential problems can be ruled out).', 'green')
2039
2394
 
2040
2395
  def get_action(self, key: random.PRNGKey,
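Before the simulator changes, it is worth unpacking the estimator that `GaussianPGPE.compile` wires up above. With super-symmetric sampling and reward scaling switched off, it reduces to the classic symmetric-sampling PGPE update: draw ε ~ N(0, σ²), evaluate the return at μ ± ε, and form finite-difference gradients for μ and σ. A condensed, self-contained sketch of that reduced case (an illustration, not the package API; as in the code above, gradients are negated so an optax minimizer ascends the reward):

```python
# Sketch: symmetric-sampling PGPE gradients for a flat parameter vector,
# mirroring the reduced (non-super-symmetric, unscaled) branch of the code above.
import jax
import jax.numpy as jnp

def pgpe_grads(key, mu, sigma, reward_fn):
    # draw one symmetric pair of parameter vectors around the mean
    epsilon = sigma * jax.random.normal(key, shape=mu.shape)
    r_plus, r_minus = reward_fn(mu + epsilon), reward_fn(mu - epsilon)
    # mean gradient: finite difference of the paired rewards along epsilon
    grad_mu = -0.5 * (r_plus - r_minus) * epsilon
    # std-dev gradient: baseline-free estimate via s = eps^2 / sigma - sigma
    s = epsilon * epsilon / sigma - sigma
    grad_sigma = -0.5 * (r_plus + r_minus) * s
    return grad_mu, grad_sigma

# toy usage: gradients for maximizing a concave quadratic reward
key = jax.random.PRNGKey(0)
g_mu, g_sigma = pgpe_grads(key, jnp.zeros(3), jnp.ones(3),
                           lambda p: -jnp.sum(p ** 2))
```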
@@ -1,3 +1,23 @@
1
+ # ***********************************************************************
2
+ # JAXPLAN
3
+ #
4
+ # Author: Michael Gimelfarb
5
+ #
6
+ # REFERENCES:
7
+ #
8
+ # [1] Gimelfarb, Michael, Ayal Taitler, and Scott Sanner. "JaxPlan and GurobiPlan:
9
+ # Optimization Baselines for Replanning in Discrete and Mixed Discrete-Continuous
10
+ # Probabilistic Domains." Proceedings of the International Conference on Automated
11
+ # Planning and Scheduling. Vol. 34. 2024.
12
+ #
13
+ # [2] Taitler, Ayal, Michael Gimelfarb, Jihwan Jeong, Sriram Gopalakrishnan, Martin
14
+ # Mladenov, Xiaotian Liu, and Scott Sanner. "pyRDDLGym: From RDDL to Gym Environments."
15
+ # In PRL Workshop Series - Bridging the Gap Between AI Planning and
16
+ # Reinforcement Learning.
17
+ #
18
+ # ***********************************************************************
19
+
20
+
1
21
  import time
2
22
  from typing import Dict, Optional
3
23
 
@@ -1,3 +1,18 @@
1
+ # ***********************************************************************
2
+ # JAXPLAN
3
+ #
4
+ # Author: Michael Gimelfarb
5
+ #
6
+ # REFERENCES:
7
+ #
8
+ # [1] Gimelfarb, Michael, Ayal Taitler, and Scott Sanner. "JaxPlan and GurobiPlan:
9
+ # Optimization Baselines for Replanning in Discrete and Mixed Discrete-Continuous
10
+ # Probabilistic Domains." Proceedings of the International Conference on Automated
11
+ # Planning and Scheduling. Vol. 34. 2024.
12
+ #
13
+ # ***********************************************************************
14
+
15
+
1
16
  import csv
2
17
  import datetime
3
18
  import threading
@@ -1,3 +1,18 @@
1
+ # ***********************************************************************
2
+ # JAXPLAN
3
+ #
4
+ # Author: Michael Gimelfarb
5
+ #
6
+ # REFERENCES:
7
+ #
8
+ # [1] Gimelfarb, Michael, Ayal Taitler, and Scott Sanner. "JaxPlan and GurobiPlan:
9
+ # Optimization Baselines for Replanning in Discrete and Mixed Discrete-Continuous
10
+ # Probabilistic Domains." Proceedings of the International Conference on Automated
11
+ # Planning and Scheduling. Vol. 34. 2024.
12
+ #
13
+ # ***********************************************************************
14
+
15
+
1
16
  import ast
2
17
  import os
3
18
  from datetime import datetime
@@ -61,6 +76,7 @@ class JaxPlannerDashboard:
61
76
  self.xticks = {}
62
77
  self.test_return = {}
63
78
  self.train_return = {}
79
+ self.pgpe_return = {}
64
80
  self.return_dist = {}
65
81
  self.return_dist_ticks = {}
66
82
  self.return_dist_last_progress = {}
@@ -299,6 +315,9 @@ class JaxPlannerDashboard:
299
315
  dbc.Col(Graph(id='train-return-graph'), width=6),
300
316
  dbc.Col(Graph(id='test-return-graph'), width=6),
301
317
  ]),
318
+ dbc.Row([
319
+ dbc.Col(Graph(id='pgpe-return-graph'), width=6)
320
+ ]),
302
321
  dbc.Row([
303
322
  Graph(id='dist-return-graph')
304
323
  ])
@@ -661,6 +680,33 @@ class JaxPlannerDashboard:
661
680
  )
662
681
  return fig
663
682
 
683
+ @app.callback(
684
+ Output('pgpe-return-graph', 'figure'),
685
+ [Input('interval', 'n_intervals'),
686
+ Input('trigger-experiment-check', 'children'),
687
+ Input('tabs-main', 'active_tab')]
688
+ )
689
+ def update_pgpe_return_graph(n, trigger, active_tab):
690
+ if active_tab != 'tab-performance': return dash.no_update
691
+ fig = go.Figure()
692
+ for (row, checked) in self.checked.copy().items():
693
+ if checked:
694
+ fig.add_trace(go.Scatter(
695
+ x=self.xticks[row], y=self.pgpe_return[row],
696
+ name=f'id={row}',
697
+ mode='lines+markers',
698
+ marker=dict(size=3), line=dict(width=2)
699
+ ))
700
+ fig.update_layout(
701
+ title=dict(text="PGPE Return"),
702
+ xaxis=dict(title=dict(text="Training Iteration")),
703
+ yaxis=dict(title=dict(text="Cumulative Reward")),
704
+ font=dict(size=PLOT_AXES_FONT_SIZE),
705
+ legend=dict(bgcolor='rgba(0,0,0,0)'),
706
+ template="plotly_white"
707
+ )
708
+ return fig
709
+
664
710
  @app.callback(
665
711
  Output('dist-return-graph', 'figure'),
666
712
  [Input('interval', 'n_intervals'),
@@ -1316,6 +1362,7 @@ class JaxPlannerDashboard:
1316
1362
  self.xticks[experiment_id] = []
1317
1363
  self.train_return[experiment_id] = []
1318
1364
  self.test_return[experiment_id] = []
1365
+ self.pgpe_return[experiment_id] = []
1319
1366
  self.return_dist_ticks[experiment_id] = []
1320
1367
  self.return_dist_last_progress[experiment_id] = 0
1321
1368
  self.return_dist[experiment_id] = []
@@ -1367,6 +1414,7 @@ class JaxPlannerDashboard:
1367
1414
  self.xticks[experiment_id].append(iteration)
1368
1415
  self.train_return[experiment_id].append(callback['train_return'])
1369
1416
  self.test_return[experiment_id].append(callback['best_return'])
1417
+ self.pgpe_return[experiment_id].append(callback['pgpe_return'])
1370
1418
 
1371
1419
  # data for return distributions
1372
1420
  progress = callback['progress']
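The dashboard additions above append `callback['pgpe_return']` on every iteration, but in the optimize loop that entry is `None` whenever PGPE is disabled (`pgpe=None`), so consumers of the callback should guard for it. A minimal sketch, with hand-written callback dicts standing in for the planner's output (the key names are the ones added in this diff; the values are illustrative):

```python
# Illustrative callbacks only; key names match those added to the optimize loop.
callbacks = [
    {'iteration': 0, 'train_return': -14.2, 'best_return': -13.0, 'pgpe_return': -13.7},
    {'iteration': 1, 'train_return': -12.9, 'best_return': -12.9, 'pgpe_return': None},
]
# guard against None before plotting, since pgpe_return is None when pgpe=None
pgpe_curve = [(cb['iteration'], cb['pgpe_return'])
              for cb in callbacks if cb['pgpe_return'] is not None]
```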
@@ -1,8 +1,8 @@
1
1
  [Model]
2
2
  logic='FuzzyLogic'
3
- comparison_kwargs={'weight': 50}
4
- rounding_kwargs={'weight': 50}
5
- control_kwargs={'weight': 50}
3
+ comparison_kwargs={'weight': 20}
4
+ rounding_kwargs={'weight': 20}
5
+ control_kwargs={'weight': 20}
6
6
 
7
7
  [Optimizer]
8
8
  method='JaxStraightLinePlan'
@@ -1,14 +1,14 @@
1
1
  [Model]
2
2
  logic='FuzzyLogic'
3
- comparison_kwargs={'weight': 30}
4
- rounding_kwargs={'weight': 30}
5
- control_kwargs={'weight': 30}
3
+ comparison_kwargs={'weight': 20}
4
+ rounding_kwargs={'weight': 20}
5
+ control_kwargs={'weight': 20}
6
6
 
7
7
  [Optimizer]
8
8
  method='JaxStraightLinePlan'
9
9
  method_kwargs={}
10
10
  optimizer='rmsprop'
11
- optimizer_kwargs={'learning_rate': 0.002}
11
+ optimizer_kwargs={'learning_rate': 0.001}
12
12
  batch_size_train=1
13
13
  batch_size_test=1
14
14
  clip_grad=1.0
@@ -11,6 +11,7 @@ optimizer='rmsprop'
11
11
  optimizer_kwargs={'learning_rate': 0.001}
12
12
  batch_size_train=1
13
13
  batch_size_test=1
14
+ pgpe=None
14
15
 
15
16
  [Training]
16
17
  key=42
@@ -0,0 +1,19 @@
1
+ [Model]
2
+ logic='FuzzyLogic'
3
+ comparison_kwargs={'weight': 10}
4
+ rounding_kwargs={'weight': 10}
5
+ control_kwargs={'weight': 10}
6
+
7
+ [Optimizer]
8
+ method='JaxStraightLinePlan'
9
+ method_kwargs={}
10
+ optimizer='rmsprop'
11
+ optimizer_kwargs={'learning_rate': 0.03}
12
+ batch_size_train=1
13
+ batch_size_test=1
14
+ pgpe=None
15
+
16
+ [Training]
17
+ key=42
18
+ epochs=100000
19
+ train_seconds=360
@@ -11,6 +11,7 @@ optimizer='rmsprop'
11
11
  optimizer_kwargs={'learning_rate': 0.0002}
12
12
  batch_size_train=32
13
13
  batch_size_test=32
14
+ pgpe=None
14
15
 
15
16
  [Training]
16
17
  key=42
@@ -11,6 +11,7 @@ optimizer='rmsprop'
11
11
  optimizer_kwargs={'learning_rate': 0.2}
12
12
  batch_size_train=32
13
13
  batch_size_test=32
14
+ pgpe=None
14
15
 
15
16
  [Training]
16
17
  key=42
@@ -11,6 +11,7 @@ optimizer='rmsprop'
11
11
  optimizer_kwargs={'learning_rate': 0.0003}
12
12
  batch_size_train=1
13
13
  batch_size_test=1
14
+ pgpe=None
14
15
 
15
16
  [Training]
16
17
  key=42
@@ -11,6 +11,7 @@ optimizer='rmsprop'
11
11
  optimizer_kwargs={'learning_rate': 0.001}
12
12
  batch_size_train=32
13
13
  batch_size_test=32
14
+ pgpe=None
14
15
 
15
16
  [Training]
16
17
  key=42
@@ -12,6 +12,7 @@ optimizer_kwargs={'learning_rate': 0.1}
12
12
  batch_size_train=32
13
13
  batch_size_test=32
14
14
  rollout_horizon=5
15
+ pgpe=None
15
16
 
16
17
  [Training]
17
18
  key=42
@@ -11,6 +11,7 @@ optimizer='rmsprop'
11
11
  optimizer_kwargs={'learning_rate': 0.01}
12
12
  batch_size_train=32
13
13
  batch_size_test=32
14
+ pgpe=None
14
15
 
15
16
  [Training]
16
17
  key=42
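Most example configs in this release opt out of the new fallback with `pgpe=None`. When the key is left at its default, the loader instantiates `GaussianPGPE` and resolves any nested optimizer name against optax, roughly as sketched below (the dict literal stands in for a parsed `[Optimizer]` section and is not part of the package):

```python
# Sketch of the pgpe resolution added to _load_config above; the literal dict
# stands in for a parsed [Optimizer] section, not an actual config file.
import optax
from pyRDDLGym_jax.core.planner import GaussianPGPE

planner_args = {'pgpe': 'GaussianPGPE',
                'pgpe_kwargs': {'optimizer': 'adam',
                                'optimizer_kwargs_mu': {'learning_rate': 0.05}}}

pgpe_kwargs = planner_args.pop('pgpe_kwargs', {})
if 'optimizer' in pgpe_kwargs:
    # the loader maps the optimizer name to the optax factory of the same name
    pgpe_kwargs['optimizer'] = getattr(optax, pgpe_kwargs['optimizer'])
planner_args['pgpe'] = GaussianPGPE(**pgpe_kwargs)  # pgpe=None in a .cfg skips this
```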
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: pyRDDLGym-jax
3
- Version: 1.3
3
+ Version: 2.0
4
4
  Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
5
5
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
6
6
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -19,7 +19,7 @@ long_description = (Path(__file__).parent / "README.md").read_text()
19
19
 
20
20
  setup(
21
21
  name='pyRDDLGym-jax',
22
- version='1.3',
22
+ version='2.0',
23
23
  author="Michael Gimelfarb, Ayal Taitler, Scott Sanner",
24
24
  author_email="mike.gimelfarb@mail.utoronto.ca, ataitler@gmail.com, ssanner@mie.utoronto.ca",
25
25
  description="pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.",
@@ -1 +0,0 @@
1
- __version__ = '1.3'
@@ -1,18 +0,0 @@
1
- [Model]
2
- logic='FuzzyLogic'
3
- comparison_kwargs={'weight': 50}
4
- rounding_kwargs={'weight': 50}
5
- control_kwargs={'weight': 50}
6
-
7
- [Optimizer]
8
- method='JaxStraightLinePlan'
9
- method_kwargs={}
10
- optimizer='rmsprop'
11
- optimizer_kwargs={'learning_rate': 0.03}
12
- batch_size_train=1
13
- batch_size_test=1
14
-
15
- [Training]
16
- key=42
17
- epochs=100000
18
- train_seconds=360