pyRDDLGym-jax 2.1.tar.gz → 2.3.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/PKG-INFO +25 -22
  2. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/README.md +24 -21
  3. pyrddlgym_jax-2.3/pyRDDLGym_jax/__init__.py +1 -0
  4. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/core/compiler.py +14 -8
  5. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/core/logic.py +118 -55
  6. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/core/planner.py +159 -76
  7. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax.egg-info/PKG-INFO +25 -22
  8. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/setup.py +1 -1
  9. pyrddlgym_jax-2.1/pyRDDLGym_jax/__init__.py +0 -1
  10. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/LICENSE +0 -0
  11. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/core/__init__.py +0 -0
  12. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/core/assets/__init__.py +0 -0
  13. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/core/assets/favicon.ico +0 -0
  14. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/core/simulator.py +0 -0
  15. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/core/tuning.py +0 -0
  16. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/core/visualization.py +0 -0
  17. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/entry_point.py +0 -0
  18. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/__init__.py +0 -0
  19. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +0 -0
  20. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +0 -0
  21. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +0 -0
  22. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +0 -0
  23. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +0 -0
  24. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +0 -0
  25. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +0 -0
  26. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +0 -0
  27. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +0 -0
  28. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +0 -0
  29. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +0 -0
  30. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +0 -0
  31. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +0 -0
  32. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +0 -0
  33. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +0 -0
  34. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +0 -0
  35. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +0 -0
  36. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +0 -0
  37. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +0 -0
  38. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
  39. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/default_drp.cfg +0 -0
  40. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/default_replan.cfg +0 -0
  41. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/default_slp.cfg +0 -0
  42. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/tuning_drp.cfg +0 -0
  43. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/tuning_replan.cfg +0 -0
  44. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/tuning_slp.cfg +0 -0
  45. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/run_gradient.py +0 -0
  46. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/run_gym.py +0 -0
  47. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/run_plan.py +0 -0
  48. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/run_scipy.py +0 -0
  49. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/run_tune.py +0 -0
  50. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax.egg-info/SOURCES.txt +0 -0
  51. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax.egg-info/dependency_links.txt +0 -0
  52. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax.egg-info/entry_points.txt +0 -0
  53. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax.egg-info/requires.txt +0 -0
  54. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/pyRDDLGym_jax.egg-info/top_level.txt +0 -0
  55. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.3}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: pyRDDLGym-jax
- Version: 2.1
+ Version: 2.3
  Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -58,8 +58,11 @@ Dynamic: summary

  Purpose:

- 1. automatic translation of any RDDL description file into a differentiable simulator in JAX
- 2. flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and Bayesian hyper-parameter tuning.
+ 1. automatic translation of RDDL description files into differentiable JAX simulators
+ 2. implementation of (highly configurable) operator relaxations for working in discrete and hybrid domains
+ 3. flexible policy representations and automated Bayesian hyper-parameter tuning
+ 4. interactive dashboard for dynamic visualization and debugging
+ 5. hybridization with parameter-exploring policy gradients.

  Some demos of solved problems by JaxPlan:

@@ -235,8 +238,23 @@ More documentation about this and other new features will be coming soon.

  ## Tuning the Planner

- It is easy to tune the planner's hyper-parameters efficiently and automatically using Bayesian optimization.
- To do this, first create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
+ A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:
+
+ ```shell
+ jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
+ ```
+
+ where:
+ - ``domain`` is the domain identifier as specified in rddlrepository
+ - ``instance`` is the instance identifier
+ - ``method`` is the planning method to use (i.e. drp, slp, replan)
+ - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
+ - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
+ - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, i.e. the total evaluations = ``iters * workers``
+ - ``dashboard`` is whether the optimizations are tracked in the dashboard application.
+
+ It is easy to tune a custom range of the planner's hyper-parameters efficiently.
+ First create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:

  ```ini
  [Model]
@@ -260,7 +278,7 @@ train_on_reset=True

  would allow tuning the sharpness of model relaxations and the learning rate of the optimizer.

- Next, you must link the patterns in the config with concrete hyper-parameter ranges the tuner will understand:
+ Next, you must link the patterns in the config with concrete hyper-parameter ranges the tuner will understand, and run the optimizer:

  ```python
  import pyRDDLGym
@@ -292,22 +310,7 @@ tuning = JaxParameterTuning(env=env,
  gp_iters=iters)
  tuning.tune(key=42, log_file='path/to/log.csv')
  ```
-
- A basic run script is provided to run the automatic hyper-parameter tuning for the most sensitive parameters of JaxPlan:
-
- ```shell
- jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
- ```
-
- where:
- - ``domain`` is the domain identifier as specified in rddlrepository
- - ``instance`` is the instance identifier
- - ``method`` is the planning method to use (i.e. drp, slp, replan)
- - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
- - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
- - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, i.e. the total evaluations = ``iters * workers``
- - ``dashboard`` is whether the optimizations are tracked in the dashboard application.
-
+

  ## Simulation

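Editorial note: the config-template mechanism described in the tuning section above is simple enough to picture outside the library. The sketch below is an illustration only — the `TUNABLE_WEIGHT`/`TUNABLE_LR` pattern names, the range table, and the `instantiate` helper are hypothetical and not the pyRDDLGym-jax API — showing how a template with placeholder patterns can be filled in from one candidate point proposed by a Bayesian optimizer:

```python
import configparser
import io

# hypothetical config template: patterns stand in for concrete values to tune
TEMPLATE = """
[Model]
comparison_kwargs={'weight': TUNABLE_WEIGHT}

[Optimizer]
learning_rate=TUNABLE_LR
"""

# hypothetical search ranges; a Bayesian optimizer would propose points inside them
RANGES = {'TUNABLE_WEIGHT': (1.0, 1000.0), 'TUNABLE_LR': (1e-5, 1.0)}

def instantiate(template: str, point: dict) -> configparser.ConfigParser:
    """Substitute one sampled hyper-parameter point into the template."""
    filled = template
    for pattern, value in point.items():
        filled = filled.replace(pattern, repr(value))
    config = configparser.ConfigParser()
    config.read_file(io.StringIO(filled))
    return config

# one candidate configuration, e.g. as proposed by the tuner
config = instantiate(TEMPLATE, {'TUNABLE_WEIGHT': 10.0, 'TUNABLE_LR': 0.01})
print(dict(config['Optimizer']))   # {'learning_rate': '0.01'}
```

Each instantiated config would then be scored by running the planner, with the score fed back to the optimizer to propose the next point.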
@@ -12,8 +12,11 @@

  Purpose:

- 1. automatic translation of any RDDL description file into a differentiable simulator in JAX
- 2. flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and Bayesian hyper-parameter tuning.
+ 1. automatic translation of RDDL description files into differentiable JAX simulators
+ 2. implementation of (highly configurable) operator relaxations for working in discrete and hybrid domains
+ 3. flexible policy representations and automated Bayesian hyper-parameter tuning
+ 4. interactive dashboard for dynamic visualization and debugging
+ 5. hybridization with parameter-exploring policy gradients.

  Some demos of solved problems by JaxPlan:

@@ -189,8 +192,23 @@ More documentation about this and other new features will be coming soon.

  ## Tuning the Planner

- It is easy to tune the planner's hyper-parameters efficiently and automatically using Bayesian optimization.
- To do this, first create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
+ A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:
+
+ ```shell
+ jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
+ ```
+
+ where:
+ - ``domain`` is the domain identifier as specified in rddlrepository
+ - ``instance`` is the instance identifier
+ - ``method`` is the planning method to use (i.e. drp, slp, replan)
+ - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
+ - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
+ - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, i.e. the total evaluations = ``iters * workers``
+ - ``dashboard`` is whether the optimizations are tracked in the dashboard application.
+
+ It is easy to tune a custom range of the planner's hyper-parameters efficiently.
+ First create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:

  ```ini
  [Model]
@@ -214,7 +232,7 @@ train_on_reset=True

  would allow tuning the sharpness of model relaxations and the learning rate of the optimizer.

- Next, you must link the patterns in the config with concrete hyper-parameter ranges the tuner will understand:
+ Next, you must link the patterns in the config with concrete hyper-parameter ranges the tuner will understand, and run the optimizer:

  ```python
  import pyRDDLGym
@@ -246,22 +264,7 @@ tuning = JaxParameterTuning(env=env,
  gp_iters=iters)
  tuning.tune(key=42, log_file='path/to/log.csv')
  ```
-
- A basic run script is provided to run the automatic hyper-parameter tuning for the most sensitive parameters of JaxPlan:
-
- ```shell
- jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
- ```
-
- where:
- - ``domain`` is the domain identifier as specified in rddlrepository
- - ``instance`` is the instance identifier
- - ``method`` is the planning method to use (i.e. drp, slp, replan)
- - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
- - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
- - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, i.e. the total evaluations = ``iters * workers``
- - ``dashboard`` is whether the optimizations are tracked in the dashboard application.
-
+

  ## Simulation

@@ -0,0 +1 @@
+ __version__ = '2.3'
@@ -1019,6 +1019,9 @@ class JaxRDDLCompiler:
  # UnnormDiscrete: complete (subclass uses Gumbel-softmax)
  # Discrete(p): complete (subclass uses Gumbel-softmax)
  # UnnormDiscrete(p): complete (subclass uses Gumbel-softmax)
+ # Poisson (subclass uses Gumbel-softmax or Poisson process trick)
+ # Binomial (subclass uses Gumbel-softmax or Normal approximation)
+ # NegativeBinomial (subclass uses Poisson-Gamma mixture)

  # distributions which seem to support backpropagation (need more testing):
  # Beta
@@ -1026,11 +1029,8 @@ class JaxRDDLCompiler:
  # Gamma
  # ChiSquare
  # Dirichlet
- # Poisson (subclass uses Gumbel-softmax or Poisson process trick)

  # distributions with incomplete reparameterization support (TODO):
- # Binomial
- # NegativeBinomial
  # Multinomial

  def _jax_random(self, expr, init_params):
@@ -1299,8 +1299,17 @@ class JaxRDDLCompiler:
  def _jax_negative_binomial(self, expr, init_params):
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_NEGATIVE_BINOMIAL']
  JaxRDDLCompiler._check_num_args(expr, 2)
-
  arg_trials, arg_prob = expr.args
+
+ # if trials and prob are both non-fluent, always use the exact operation
+ if self.compile_non_fluent_exact \
+ and not self.traced.cached_is_fluent(arg_trials) \
+ and not self.traced.cached_is_fluent(arg_prob):
+ negbin_op = self.EXACT_OPS['sampling']['NegativeBinomial']
+ else:
+ negbin_op = self.OPS['sampling']['NegativeBinomial']
+ jax_op = negbin_op(expr.id, init_params)
+
  jax_trials = self._jax(arg_trials, init_params)
  jax_prob = self._jax(arg_prob, init_params)

@@ -1308,11 +1317,8 @@ class JaxRDDLCompiler:
  def _jax_wrapped_distribution_negative_binomial(x, params, key):
  trials, key, err2, params = jax_trials(x, params, key)
  prob, key, err1, params = jax_prob(x, params, key)
- trials = jnp.asarray(trials, dtype=self.REAL)
- prob = jnp.asarray(prob, dtype=self.REAL)
  key, subkey = random.split(key)
- dist = tfp.distributions.NegativeBinomial(total_count=trials, probs=prob)
- sample = jnp.asarray(dist.sample(seed=subkey), dtype=self.INT)
+ sample, params = jax_op(subkey, trials, prob, params)
  out_of_bounds = jnp.logical_not(jnp.all(
  (prob >= 0) & (prob <= 1) & (trials > 0)))
  err = err1 | err2 | (out_of_bounds * ERR)
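Editorial note: the compiler hunk above routes `NegativeBinomial` through one of two operator tables depending on whether its arguments are fluent. A minimal sketch of that dispatch pattern follows (assumptions: `exact_ops`/`relaxed_ops` stand in for the compiler's `EXACT_OPS`/`OPS` tables, and the boolean flags for `self.traced.cached_is_fluent`; this is not the compiler itself):

```python
from typing import Callable, Dict

def pick_negbin_sampler(exact_ops: Dict, relaxed_ops: Dict,
                        compile_non_fluent_exact: bool,
                        trials_is_fluent: bool,
                        prob_is_fluent: bool) -> Callable:
    # arguments that never depend on fluents are constants of the planning
    # problem, so no gradient needs to flow through the sample: the exact
    # (non-differentiable) sampler is safe and adds no relaxation error
    if compile_non_fluent_exact and not trials_is_fluent and not prob_is_fluent:
        return exact_ops['sampling']['NegativeBinomial']
    # otherwise use the smooth relaxation so the planner can backpropagate
    return relaxed_ops['sampling']['NegativeBinomial']
```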
@@ -29,15 +29,27 @@
  #
  # ***********************************************************************

- from typing import Callable, Dict, Union
+ import traceback
+ from typing import Callable, Dict, Tuple, Union

  import jax
  import jax.numpy as jnp
  import jax.random as random
  import jax.scipy as scipy

+ from pyRDDLGym.core.debug.exception import raise_warning

- def enumerate_literals(shape, axis, dtype=jnp.int32):
+ # more robust approach: if the user does not have tfp, or it is broken, try to continue
+ try:
+ from tensorflow_probability.substrates import jax as tfp
+ except Exception:
+ raise_warning('Failed to import tensorflow-probability: '
+ 'compilation of some probability distributions will fail.', 'red')
+ traceback.print_exc()
+ tfp = None
+
+
+ def enumerate_literals(shape: Tuple[int, ...], axis: int, dtype: type=jnp.int32) -> jnp.ndarray:
  literals = jnp.arange(shape[axis], dtype=dtype)
  literals = literals[(...,) + (jnp.newaxis,) * (len(shape) - 1)]
  literals = jnp.moveaxis(literals, source=0, destination=axis)
@@ -74,7 +86,7 @@ class Comparison:
  class SigmoidComparison(Comparison):
  '''Comparison operations approximated using sigmoid functions.'''

- def __init__(self, weight: float=10.0):
+ def __init__(self, weight: float=10.0) -> None:
  self.weight = weight

  # https://arxiv.org/abs/2110.05651
@@ -140,7 +152,7 @@ class Rounding:
  class SoftRounding(Rounding):
  '''Rounding operations approximated using soft operations.'''

- def __init__(self, weight: float=10.0):
+ def __init__(self, weight: float=10.0) -> None:
  self.weight = weight

  # https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/bijectors/Softfloor
@@ -291,7 +303,7 @@ class YagerTNorm(TNorm):
  '''Yager t-norm given by the expression
  (x, y) -> max(0, 1 - ((1 - x)^p + (1 - y)^p)^(1/p)).'''

- def __init__(self, p=2.0):
+ def __init__(self, p: float=2.0) -> None:
  self.p = float(p)

  def norm(self, id, init_params):
@@ -339,6 +351,9 @@ class RandomSampling:
  def binomial(self, id, init_params, logic):
  raise NotImplementedError

+ def negative_binomial(self, id, init_params, logic):
+ raise NotImplementedError
+
  def geometric(self, id, init_params, logic):
  raise NotImplementedError

@@ -386,8 +401,7 @@ class SoftRandomSampling(RandomSampling):
  def _poisson_gumbel_softmax(self, id, init_params, logic):
  argmax_approx = logic.argmax(id, init_params)
  def _jax_wrapped_calc_poisson_gumbel_softmax(key, rate, params):
- ks = jnp.arange(0, self.poisson_bins)
- ks = ks[(jnp.newaxis,) * jnp.ndim(rate) + (...,)]
+ ks = jnp.arange(self.poisson_bins)[(jnp.newaxis,) * jnp.ndim(rate) + (...,)]
  rate = rate[..., jnp.newaxis]
  log_prob = ks * jnp.log(rate + logic.eps) - rate - scipy.special.gammaln(ks + 1)
  Gumbel01 = random.gumbel(key=key, shape=jnp.shape(log_prob), dtype=logic.REAL)
@@ -400,10 +414,7 @@ class SoftRandomSampling(RandomSampling):
  less_approx = logic.less(id, init_params)
  def _jax_wrapped_calc_poisson_exponential(key, rate, params):
  Exp1 = random.exponential(
- key=key,
- shape=(self.poisson_bins,) + jnp.shape(rate),
- dtype=logic.REAL
- )
+ key=key, shape=(self.poisson_bins,) + jnp.shape(rate), dtype=logic.REAL)
  delta_t = Exp1 / rate[jnp.newaxis, ...]
  times = jnp.cumsum(delta_t, axis=0)
  indicator, params = less_approx(times, 1.0, params)
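Editorial note: the Poisson-process trick above counts how many exponential inter-arrival times fit into the unit interval; the relaxation only replaces the hard `times < 1` indicator with a soft comparison. A minimal non-relaxed sketch (editorial illustration with an assumed truncation of 100 bins, not the library code):

```python
import jax
import jax.numpy as jnp
import jax.random as random

def poisson_via_counting_process(key, rate, bins=100):
    # inter-arrival times of a rate-`rate` Poisson process are Exponential(rate)
    exp1 = random.exponential(key, shape=(bins,) + jnp.shape(rate))
    delta_t = exp1 / rate
    times = jnp.cumsum(delta_t, axis=0)
    # the number of arrivals inside [0, 1] is Poisson(rate); the library's
    # relaxation replaces this hard indicator with a soft less-than
    return jnp.sum(times < 1.0, axis=0)

key = random.PRNGKey(0)
samples = jax.vmap(
    lambda k: poisson_via_counting_process(k, jnp.array(4.0)))(
    random.split(key, 10000))
print(samples.mean())   # ≈ 4.0 = E[Poisson(4)]
```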
@@ -411,72 +422,98 @@ class SoftRandomSampling(RandomSampling):
  return sample, params
  return _jax_wrapped_calc_poisson_exponential

+ # normal approximation to Poisson: Poisson(rate) -> Normal(rate, rate)
+ def _poisson_normal_approx(self, logic):
+ def _jax_wrapped_calc_poisson_normal_approx(key, rate, params):
+ normal = random.normal(key=key, shape=jnp.shape(rate), dtype=logic.REAL)
+ sample = rate + jnp.sqrt(rate) * normal
+ return sample, params
+ return _jax_wrapped_calc_poisson_normal_approx
+
  def poisson(self, id, init_params, logic):
- def _jax_wrapped_calc_poisson_exact(key, rate, params):
- sample = random.poisson(key=key, lam=rate, dtype=logic.INT)
- sample = jnp.asarray(sample, dtype=logic.REAL)
- return sample, params
-
  if self.poisson_exp_method:
  _jax_wrapped_calc_poisson_diff = self._poisson_exponential(
  id, init_params, logic)
  else:
  _jax_wrapped_calc_poisson_diff = self._poisson_gumbel_softmax(
  id, init_params, logic)
+ _jax_wrapped_calc_poisson_normal = self._poisson_normal_approx(logic)

+ # for small rate use the Poisson process or gumbel-softmax reparameterization
+ # for large rate use the normal approximation
  def _jax_wrapped_calc_poisson_approx(key, rate, params):
-
- # determine if error of truncation at rate is acceptable
  if self.poisson_bins > 0:
  cuml_prob = scipy.stats.poisson.cdf(self.poisson_bins, rate)
- approx_cond = jax.lax.stop_gradient(
- jnp.min(cuml_prob) > self.poisson_min_cdf)
+ small_rate = jax.lax.stop_gradient(cuml_prob >= self.poisson_min_cdf)
+ small_sample, params = _jax_wrapped_calc_poisson_diff(key, rate, params)
+ large_sample, params = _jax_wrapped_calc_poisson_normal(key, rate, params)
+ sample = jnp.where(small_rate, small_sample, large_sample)
+ return sample, params
  else:
- approx_cond = False
-
- # for acceptable truncation use the approximation, use exact otherwise
- return jax.lax.cond(
- approx_cond,
- _jax_wrapped_calc_poisson_diff,
- _jax_wrapped_calc_poisson_exact,
- key, rate, params
- )
+ return _jax_wrapped_calc_poisson_normal(key, rate, params)
  return _jax_wrapped_calc_poisson_approx

- def binomial(self, id, init_params, logic):
- def _jax_wrapped_calc_binomial_exact(key, trials, prob, params):
- trials = jnp.asarray(trials, dtype=logic.REAL)
- prob = jnp.asarray(prob, dtype=logic.REAL)
- sample = random.binomial(key=key, n=trials, p=prob, dtype=logic.REAL)
- return sample, params
+ # normal approximation to Binomial: Bin(n, p) -> Normal(np, np(1-p))
+ def _binomial_normal_approx(self, logic):
+ def _jax_wrapped_calc_binomial_normal_approx(key, trials, prob, params):
+ normal = random.normal(key=key, shape=jnp.shape(trials), dtype=logic.REAL)
+ mean = trials * prob
+ std = jnp.sqrt(trials * prob * (1.0 - prob))
+ sample = mean + std * normal
+ return sample, params
+ return _jax_wrapped_calc_binomial_normal_approx

- # Binomial(n, p) = sum_{i = 1 ... n} Bernoulli(p)
- bernoulli_approx = self.bernoulli(id, init_params, logic)
- def _jax_wrapped_calc_binomial_sum(key, trials, prob, params):
- prob_full = jnp.broadcast_to(
- prob[..., jnp.newaxis], shape=jnp.shape(prob) + (self.binomial_bins,))
- sample_bern, params = bernoulli_approx(key, prob_full, params)
- indices = jnp.arange(self.binomial_bins)[
- (jnp.newaxis,) * jnp.ndim(prob) + (...,)]
- mask = indices < trials[..., jnp.newaxis]
- sample = jnp.sum(sample_bern * mask, axis=-1)
- return sample, params
+ def _binomial_gumbel_softmax(self, id, init_params, logic):
+ argmax_approx = logic.argmax(id, init_params)
+ def _jax_wrapped_calc_binomial_gumbel_softmax(key, trials, prob, params):
+ ks = jnp.arange(self.binomial_bins)[(jnp.newaxis,) * jnp.ndim(trials) + (...,)]
+ trials = trials[..., jnp.newaxis]
+ prob = prob[..., jnp.newaxis]
+ in_support = ks <= trials
+ ks = jnp.minimum(ks, trials)
+ log_prob = ((scipy.special.gammaln(trials + 1) -
+ scipy.special.gammaln(ks + 1) -
+ scipy.special.gammaln(trials - ks + 1)) +
+ ks * jnp.log(prob + logic.eps) +
+ (trials - ks) * jnp.log1p(-prob + logic.eps))
+ log_prob = jnp.where(in_support, log_prob, jnp.log(logic.eps))
+ Gumbel01 = random.gumbel(key=key, shape=jnp.shape(log_prob), dtype=logic.REAL)
+ sample = Gumbel01 + log_prob
+ return argmax_approx(sample, axis=-1, params=params)
+ return _jax_wrapped_calc_binomial_gumbel_softmax

- # for trials not too large use the Bernoulli relaxation, use exact otherwise
+ def binomial(self, id, init_params, logic):
+ _jax_wrapped_calc_binomial_normal = self._binomial_normal_approx(logic)
+ _jax_wrapped_calc_binomial_gs = self._binomial_gumbel_softmax(id, init_params, logic)
+
+ # for small trials use the Gumbel-softmax relaxation
+ # for large trials use the normal approximation
  def _jax_wrapped_calc_binomial_approx(key, trials, prob, params):
- return jax.lax.cond(
- jax.lax.stop_gradient(jnp.max(trials) < self.binomial_bins),
- _jax_wrapped_calc_binomial_sum,
- _jax_wrapped_calc_binomial_exact,
- key, trials, prob, params
- )
+ small_trials = jax.lax.stop_gradient(trials < self.binomial_bins)
+ small_sample, params = _jax_wrapped_calc_binomial_gs(key, trials, prob, params)
+ large_sample, params = _jax_wrapped_calc_binomial_normal(key, trials, prob, params)
+ sample = jnp.where(small_trials, small_sample, large_sample)
+ return sample, params
  return _jax_wrapped_calc_binomial_approx

+ # https://en.wikipedia.org/wiki/Negative_binomial_distribution#Gamma%E2%80%93Poisson_mixture
+ def negative_binomial(self, id, init_params, logic):
+ poisson_approx = self.poisson(id, init_params, logic)
+ def _jax_wrapped_calc_negative_binomial_approx(key, trials, prob, params):
+ key, subkey = random.split(key)
+ trials = jnp.asarray(trials, dtype=logic.REAL)
+ Gamma = random.gamma(key=key, a=trials, dtype=logic.REAL)
+ scale = (1.0 - prob) / prob
+ poisson_rate = scale * Gamma
+ return poisson_approx(subkey, poisson_rate, params)
+ return _jax_wrapped_calc_negative_binomial_approx
+
  def geometric(self, id, init_params, logic):
  approx_floor = logic.floor(id, init_params)
  def _jax_wrapped_calc_geometric_approx(key, prob, params):
  U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
- floor, params = approx_floor(jnp.log1p(-U) / jnp.log1p(-prob), params)
+ floor, params = approx_floor(
+ jnp.log1p(-U) / jnp.log1p(-prob + logic.eps), params)
  sample = floor + 1
  return sample, params
  return _jax_wrapped_calc_geometric_approx
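Editorial note: the Gamma–Poisson mixture above uses the identity that if λ ~ Gamma(r) scaled by (1 − p)/p, then Poisson(λ) is NegativeBinomial(r, p) counting failures before r successes. A quick Monte Carlo check of the first two moments (editorial sketch, not library code):

```python
import jax.numpy as jnp
import jax.random as random

key = random.PRNGKey(42)
r, p = 5.0, 0.3                      # successes required and success probability
k1, k2 = random.split(key)

# NB(r, p) as a Gamma-Poisson mixture: rate ~ Gamma(r) * (1 - p) / p
rates = random.gamma(k1, a=r, shape=(200000,)) * (1.0 - p) / p
samples = random.poisson(k2, lam=rates)

# failures before r successes: mean r(1-p)/p, variance r(1-p)/p^2
print(samples.mean(), r * (1 - p) / p)          # both ≈ 11.67
print(samples.var(), r * (1 - p) / p ** 2)      # both ≈ 38.89
```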
@@ -532,6 +569,14 @@ class Determinization(RandomSampling):
  def binomial(self, id, init_params, logic):
  return self._jax_wrapped_calc_binomial_determinized

+ @staticmethod
+ def _jax_wrapped_calc_negative_binomial_determinized(key, trials, prob, params):
+ sample = trials * ((1.0 / prob) - 1.0)
+ return sample, params
+
+ def negative_binomial(self, id, init_params, logic):
+ return self._jax_wrapped_calc_negative_binomial_determinized
+
  @staticmethod
  def _jax_wrapped_calc_geometric_determinized(key, prob, params):
  sample = 1.0 / prob
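Editorial note: the determinized NegativeBinomial above returns the distribution's mean; under the failures-before-successes convention used in this file,

```latex
\mathbb{E}\left[\mathrm{NB}(r, p)\right] \;=\; \frac{r\,(1 - p)}{p} \;=\; r\left(\frac{1}{p} - 1\right),
```

which is exactly the `trials * ((1.0 / prob) - 1.0)` expression in the hunk. Likewise, the geometric determinization `1.0 / prob` is the mean of a Geometric(p) whose support starts at 1, consistent with the `floor + 1` sampler above.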
@@ -712,7 +757,8 @@ class Logic:
  'Discrete': self.discrete,
  'Poisson': self.poisson,
  'Geometric': self.geometric,
- 'Binomial': self.binomial
+ 'Binomial': self.binomial,
+ 'NegativeBinomial': self.negative_binomial
  }
  }

@@ -830,6 +876,9 @@ class Logic:
  def binomial(self, id, init_params):
  raise NotImplementedError

+ def negative_binomial(self, id, init_params):
+ raise NotImplementedError
+

  class ExactLogic(Logic):
  '''A class representing exact logic in JAX.'''
@@ -1005,6 +1054,17 @@ class ExactLogic(Logic):
  sample = jnp.asarray(sample, dtype=self.INT)
  return sample, params
  return _jax_wrapped_calc_binomial_exact
+
+ # note: tfp defines this as the number of successes before ``trials`` failures;
+ # here it is defined as the number of failures before ``trials`` successes
+ def negative_binomial(self, id, init_params):
+ def _jax_wrapped_calc_negative_binomial_exact(key, trials, prob, params):
+ trials = jnp.asarray(trials, dtype=self.REAL)
+ prob = jnp.asarray(prob, dtype=self.REAL)
+ dist = tfp.distributions.NegativeBinomial(total_count=trials, probs=1.0 - prob)
+ sample = jnp.asarray(dist.sample(seed=key), dtype=self.INT)
+ return sample, params
+ return _jax_wrapped_calc_negative_binomial_exact


  class FuzzyLogic(Logic):
@@ -1234,6 +1294,9 @@ class FuzzyLogic(Logic):

  def binomial(self, id, init_params):
  return self.sampling.binomial(id, init_params, self)
+
+ def negative_binomial(self, id, init_params):
+ return self.sampling.negative_binomial(id, init_params, self)


  # ===========================================================================
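Editorial note: several relaxations in this file (the Discrete, Poisson, and Binomial samplers) rest on the Gumbel-max trick: adding independent Gumbel(0, 1) noise to log-probabilities and taking an argmax yields an exact categorical sample, and swapping the hard argmax for the relaxed `logic.argmax` makes it differentiable. A minimal demonstration of the hard version (editorial sketch):

```python
import jax.numpy as jnp
import jax.random as random

key = random.PRNGKey(7)
log_prob = jnp.log(jnp.array([0.2, 0.5, 0.3]))

# argmax(log p + Gumbel noise) is distributed as Categorical(p)
gumbel = random.gumbel(key, shape=(100000,) + log_prob.shape)
samples = jnp.argmax(log_prob + gumbel, axis=-1)

# empirical frequencies recover the target probabilities
print(jnp.bincount(samples, length=3) / samples.shape[0])   # ≈ [0.2 0.5 0.3]
```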
@@ -47,7 +47,9 @@ import jax.random as random
  import numpy as np
  import optax
  import termcolor
- from tqdm import tqdm
+ from tqdm import tqdm, TqdmWarning
+ import warnings
+ warnings.filterwarnings("ignore", category=TqdmWarning)

  from pyRDDLGym.core.compiler.model import RDDLPlanningModel, RDDLLiftedModel
  from pyRDDLGym.core.debug.logger import Logger
@@ -1212,17 +1214,22 @@ class GaussianPGPE(PGPE):
  init_sigma: float=1.0,
  sigma_range: Tuple[float, float]=(1e-5, 1e5),
  scale_reward: bool=True,
+ min_reward_scale: float=1e-5,
  super_symmetric: bool=True,
  super_symmetric_accurate: bool=True,
  optimizer: Callable[..., optax.GradientTransformation]=optax.adam,
  optimizer_kwargs_mu: Optional[Kwargs]=None,
- optimizer_kwargs_sigma: Optional[Kwargs]=None) -> None:
+ optimizer_kwargs_sigma: Optional[Kwargs]=None,
+ start_entropy_coeff: float=1e-3,
+ end_entropy_coeff: float=1e-8,
+ max_kl_update: Optional[float]=None) -> None:
  '''Creates a new Gaussian PGPE planner.

  :param batch_size: how many policy parameters to sample per optimization step
  :param init_sigma: initial standard deviation of Gaussian
  :param sigma_range: bounds to constrain standard deviation
  :param scale_reward: whether to apply reward scaling as in the paper
+ :param min_reward_scale: minimum reward scaling to avoid underflow
  :param super_symmetric: whether to use super-symmetric sampling as in the paper
  :param super_symmetric_accurate: whether to use the accurate formula for super-
  symmetric sampling or the simplified but biased formula
@@ -1231,6 +1238,9 @@ class GaussianPGPE(PGPE):
  factory for the mean optimizer
  :param optimizer_kwargs_sigma: a dictionary of parameters to pass to the SGD
  factory for the standard deviation optimizer
+ :param start_entropy_coeff: starting entropy regularization coefficient for Gaussian
+ :param end_entropy_coeff: ending entropy regularization coefficient for Gaussian
+ :param max_kl_update: bound on kl-divergence between parameter updates
  '''
  super().__init__()

@@ -1238,8 +1248,13 @@ class GaussianPGPE(PGPE):
  self.init_sigma = init_sigma
  self.sigma_range = sigma_range
  self.scale_reward = scale_reward
+ self.min_reward_scale = min_reward_scale
  self.super_symmetric = super_symmetric
  self.super_symmetric_accurate = super_symmetric_accurate
+
+ # entropy regularization penalty is decayed exponentially between these values
+ self.start_entropy_coeff = start_entropy_coeff
+ self.end_entropy_coeff = end_entropy_coeff

  # set optimizers
  if optimizer_kwargs_mu is None:
@@ -1249,36 +1264,62 @@ class GaussianPGPE(PGPE):
  optimizer_kwargs_sigma = {'learning_rate': 0.1}
  self.optimizer_kwargs_sigma = optimizer_kwargs_sigma
  self.optimizer_name = optimizer
- mu_optimizer = optimizer(**optimizer_kwargs_mu)
- sigma_optimizer = optimizer(**optimizer_kwargs_sigma)
+ try:
+ mu_optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs_mu)
+ sigma_optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs_sigma)
+ except Exception as _:
+ raise_warning(
+ f'Failed to inject hyperparameters into optax optimizer for PGPE, '
+ 'rolling back to safer method: please note that kl-divergence '
+ 'constraints will be disabled.', 'red')
+ mu_optimizer = optimizer(**optimizer_kwargs_mu)
+ sigma_optimizer = optimizer(**optimizer_kwargs_sigma)
+ max_kl_update = None
  self.optimizers = (mu_optimizer, sigma_optimizer)
+ self.max_kl = max_kl_update

  def __str__(self) -> str:
  return (f'PGPE hyper-parameters:\n'
- f' method ={self.__class__.__name__}\n'
- f' batch_size ={self.batch_size}\n'
- f' init_sigma ={self.init_sigma}\n'
- f' sigma_range ={self.sigma_range}\n'
- f' scale_reward ={self.scale_reward}\n'
- f' super_symmetric={self.super_symmetric}\n'
- f' accurate ={self.super_symmetric_accurate}\n'
- f' optimizer ={self.optimizer_name}\n'
+ f' method ={self.__class__.__name__}\n'
+ f' batch_size ={self.batch_size}\n'
+ f' init_sigma ={self.init_sigma}\n'
+ f' sigma_range ={self.sigma_range}\n'
+ f' scale_reward ={self.scale_reward}\n'
+ f' min_reward_scale ={self.min_reward_scale}\n'
+ f' super_symmetric ={self.super_symmetric}\n'
+ f' accurate ={self.super_symmetric_accurate}\n'
+ f' optimizer ={self.optimizer_name}\n'
  f' optimizer_kwargs:\n'
  f' mu ={self.optimizer_kwargs_mu}\n'
  f' sigma={self.optimizer_kwargs_sigma}\n'
+ f' start_entropy_coeff={self.start_entropy_coeff}\n'
+ f' end_entropy_coeff ={self.end_entropy_coeff}\n'
+ f' max_kl_update ={self.max_kl}\n'
  )

  def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type) -> None:
- MIN_NORM = 1e-5
  sigma0 = self.init_sigma
  sigma_range = self.sigma_range
  scale_reward = self.scale_reward
+ min_reward_scale = self.min_reward_scale
  super_symmetric = self.super_symmetric
  super_symmetric_accurate = self.super_symmetric_accurate
  batch_size = self.batch_size
  optimizers = (mu_optimizer, sigma_optimizer) = self.optimizers
-
- # initializer
+ max_kl = self.max_kl
+
+ # entropy regularization penalty is decayed exponentially by elapsed budget
+ start_entropy_coeff = self.start_entropy_coeff
+ if start_entropy_coeff == 0:
+ entropy_coeff_decay = 0
+ else:
+ entropy_coeff_decay = (self.end_entropy_coeff / start_entropy_coeff) ** 0.01
+
+ # ***********************************************************************
+ # INITIALIZATION OF POLICY
+ #
+ # ***********************************************************************
+
  def _jax_wrapped_pgpe_init(key, policy_params):
  mu = policy_params
  sigma = jax.tree_map(lambda x: sigma0 * jnp.ones_like(x), mu)
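Editorial note: the `** 0.01` in the decay factor above is tied to `progress` being measured in percent (as the `progress_percent` variable later in this diff suggests). The coefficient applied at progress t is

```latex
\mathrm{coeff}(t) \;=\; c_{\text{start}} \cdot \mathrm{decay}^{\,t}
\;=\; c_{\text{start}} \left(\frac{c_{\text{end}}}{c_{\text{start}}}\right)^{t/100},
\qquad \mathrm{coeff}(0) = c_{\text{start}}, \quad \mathrm{coeff}(100) = c_{\text{end}},
```

so the entropy penalty starts at `start_entropy_coeff` and decays to exactly `end_entropy_coeff` when the training budget is exhausted.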
@@ -1289,7 +1330,11 @@ class GaussianPGPE(PGPE):

  self._initializer = jax.jit(_jax_wrapped_pgpe_init)

- # parameter sampling functions
+ # ***********************************************************************
+ # PARAMETER SAMPLING FUNCTIONS
+ #
+ # ***********************************************************************
+
  def _jax_wrapped_mu_noise(key, sigma):
  return sigma * random.normal(key, shape=jnp.shape(sigma), dtype=real_dtype)

@@ -1299,19 +1344,20 @@ class GaussianPGPE(PGPE):
  a = (sigma - jnp.abs(epsilon)) / sigma
  if super_symmetric_accurate:
  aa = jnp.abs(a)
+ aa3 = jnp.power(aa, 3)
  epsilon_star = jnp.sign(epsilon) * phi * jnp.where(
  a <= 0,
- jnp.exp(c1 * aa * (aa * aa - 1) / jnp.log(aa + 1e-10) + c2 * aa),
- jnp.exp(aa - c3 * aa * jnp.log(1.0 - jnp.power(aa, 3) + 1e-10))
+ jnp.exp(c1 * (aa3 - aa) / jnp.log(aa + 1e-10) + c2 * aa),
+ jnp.exp(aa - c3 * aa * jnp.log(1.0 - aa3 + 1e-10))
  )
  else:
  epsilon_star = jnp.sign(epsilon) * phi * jnp.exp(a)
  return epsilon_star

  def _jax_wrapped_sample_params(key, mu, sigma):
- keys = random.split(key, num=len(jax.tree_util.tree_leaves(mu)))
- keys_pytree = jax.tree_util.tree_unflatten(
- treedef=jax.tree_util.tree_structure(mu), leaves=keys)
+ treedef = jax.tree_util.tree_structure(sigma)
+ keys = random.split(key, num=treedef.num_leaves)
+ keys_pytree = jax.tree_util.tree_unflatten(treedef=treedef, leaves=keys)
  epsilon = jax.tree_map(_jax_wrapped_mu_noise, keys_pytree, sigma)
  p1 = jax.tree_map(jnp.add, mu, epsilon)
  p2 = jax.tree_map(jnp.subtract, mu, epsilon)
@@ -1321,14 +1367,18 @@ class GaussianPGPE(PGPE):
  p4 = jax.tree_map(jnp.subtract, mu, epsilon_star)
  else:
  epsilon_star, p3, p4 = epsilon, p1, p2
- return (p1, p2, p3, p4), (epsilon, epsilon_star)
+ return p1, p2, p3, p4, epsilon, epsilon_star

- # policy gradient update functions
+ # ***********************************************************************
+ # POLICY GRADIENT CALCULATION
+ #
+ # ***********************************************************************
+
  def _jax_wrapped_mu_grad(epsilon, epsilon_star, r1, r2, r3, r4, m):
  if super_symmetric:
  if scale_reward:
- scale1 = jnp.maximum(MIN_NORM, m - (r1 + r2) / 2)
- scale2 = jnp.maximum(MIN_NORM, m - (r3 + r4) / 2)
+ scale1 = jnp.maximum(min_reward_scale, m - (r1 + r2) / 2)
+ scale2 = jnp.maximum(min_reward_scale, m - (r3 + r4) / 2)
  else:
  scale1 = scale2 = 1.0
  r_mu1 = (r1 - r2) / (2 * scale1)
@@ -1336,37 +1386,37 @@
  r_mu2 = (r3 - r4) / (2 * scale2)
  grad = -(r_mu1 * epsilon + r_mu2 * epsilon_star)
  else:
  if scale_reward:
- scale = jnp.maximum(MIN_NORM, m - (r1 + r2) / 2)
+ scale = jnp.maximum(min_reward_scale, m - (r1 + r2) / 2)
  else:
  scale = 1.0
  r_mu = (r1 - r2) / (2 * scale)
  grad = -r_mu * epsilon
  return grad

- def _jax_wrapped_sigma_grad(epsilon, epsilon_star, sigma, r1, r2, r3, r4, m):
+ def _jax_wrapped_sigma_grad(epsilon, epsilon_star, sigma, r1, r2, r3, r4, m, ent):
  if super_symmetric:
  mask = r1 + r2 >= r3 + r4
  epsilon_tau = mask * epsilon + (1 - mask) * epsilon_star
- s = epsilon_tau * epsilon_tau / sigma - sigma
+ s = jnp.square(epsilon_tau) / sigma - sigma
  if scale_reward:
- scale = jnp.maximum(MIN_NORM, m - (r1 + r2 + r3 + r4) / 4)
+ scale = jnp.maximum(min_reward_scale, m - (r1 + r2 + r3 + r4) / 4)
  else:
  scale = 1.0
  r_sigma = ((r1 + r2) - (r3 + r4)) / (4 * scale)
  else:
- s = epsilon * epsilon / sigma - sigma
+ s = jnp.square(epsilon) / sigma - sigma
  if scale_reward:
- scale = jnp.maximum(MIN_NORM, jnp.abs(m))
+ scale = jnp.maximum(min_reward_scale, jnp.abs(m))
  else:
  scale = 1.0
  r_sigma = (r1 + r2) / (2 * scale)
- grad = -r_sigma * s
+ grad = -(r_sigma * s + ent / sigma)
  return grad

- def _jax_wrapped_pgpe_grad(key, mu, sigma, r_max,
+ def _jax_wrapped_pgpe_grad(key, mu, sigma, r_max, ent,
  policy_hyperparams, subs, model_params):
  key, subkey = random.split(key)
- (p1, p2, p3, p4), (epsilon, epsilon_star) = _jax_wrapped_sample_params(
+ p1, p2, p3, p4, epsilon, epsilon_star = _jax_wrapped_sample_params(
  key, mu, sigma)
  r1 = -loss_fn(subkey, p1, policy_hyperparams, subs, model_params)[0]
  r2 = -loss_fn(subkey, p2, policy_hyperparams, subs, model_params)[0]
@@ -1384,42 +1434,76 @@
  epsilon, epsilon_star
  )
  grad_sigma = jax.tree_map(
- partial(_jax_wrapped_sigma_grad, r1=r1, r2=r2, r3=r3, r4=r4, m=r_max),
+ partial(_jax_wrapped_sigma_grad,
+ r1=r1, r2=r2, r3=r3, r4=r4, m=r_max, ent=ent),
  epsilon, epsilon_star, sigma
  )
  return grad_mu, grad_sigma, r_max

- def _jax_wrapped_pgpe_grad_batched(key, pgpe_params, r_max,
+ def _jax_wrapped_pgpe_grad_batched(key, pgpe_params, r_max, ent,
  policy_hyperparams, subs, model_params):
  mu, sigma = pgpe_params
  if batch_size == 1:
  mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad(
- key, mu, sigma, r_max, policy_hyperparams, subs, model_params)
+ key, mu, sigma, r_max, ent, policy_hyperparams, subs, model_params)
  else:
  keys = random.split(key, num=batch_size)
  mu_grads, sigma_grads, r_maxs = jax.vmap(
  _jax_wrapped_pgpe_grad,
- in_axes=(0, None, None, None, None, None, None)
- )(keys, mu, sigma, r_max, policy_hyperparams, subs, model_params)
+ in_axes=(0, None, None, None, None, None, None, None)
+ )(keys, mu, sigma, r_max, ent, policy_hyperparams, subs, model_params)
  mu_grad, sigma_grad = jax.tree_map(
  partial(jnp.mean, axis=0), (mu_grads, sigma_grads))
  new_r_max = jnp.max(r_maxs)
  return mu_grad, sigma_grad, new_r_max
+
+ # ***********************************************************************
+ # PARAMETER UPDATE
+ #
+ # ***********************************************************************

- def _jax_wrapped_pgpe_update(key, pgpe_params, r_max,
+ def _jax_wrapped_pgpe_kl_term(mu, sigma, old_mu, old_sigma):
+ return 0.5 * jnp.sum(2 * jnp.log(sigma / old_sigma) +
+ jnp.square(old_sigma / sigma) +
+ jnp.square((mu - old_mu) / sigma) - 1)
+
+ def _jax_wrapped_pgpe_update(key, pgpe_params, r_max, progress,
  policy_hyperparams, subs, model_params,
  pgpe_opt_state):
+ # regular update
  mu, sigma = pgpe_params
  mu_state, sigma_state = pgpe_opt_state
+ ent = start_entropy_coeff * jnp.power(entropy_coeff_decay, progress)
  mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad_batched(
- key, pgpe_params, r_max, policy_hyperparams, subs, model_params)
+ key, pgpe_params, r_max, ent, policy_hyperparams, subs, model_params)
  mu_updates, new_mu_state = mu_optimizer.update(mu_grad, mu_state, params=mu)
  sigma_updates, new_sigma_state = sigma_optimizer.update(
  sigma_grad, sigma_state, params=sigma)
  new_mu = optax.apply_updates(mu, mu_updates)
- new_mu, converged = projection(new_mu, policy_hyperparams)
  new_sigma = optax.apply_updates(sigma, sigma_updates)
  new_sigma = jax.tree_map(lambda x: jnp.clip(x, *sigma_range), new_sigma)
+
+ # respect KL divergence constraint with old parameters
+ if max_kl is not None:
+ old_mu_lr = new_mu_state.hyperparams['learning_rate']
+ old_sigma_lr = new_sigma_state.hyperparams['learning_rate']
+ kl_terms = jax.tree_map(
+ _jax_wrapped_pgpe_kl_term, new_mu, new_sigma, mu, sigma)
+ total_kl = jax.tree_util.tree_reduce(jnp.add, kl_terms)
+ kl_reduction = jnp.minimum(1.0, jnp.sqrt(max_kl / total_kl))
+ mu_state.hyperparams['learning_rate'] = old_mu_lr * kl_reduction
+ sigma_state.hyperparams['learning_rate'] = old_sigma_lr * kl_reduction
+ mu_updates, new_mu_state = mu_optimizer.update(mu_grad, mu_state, params=mu)
+ sigma_updates, new_sigma_state = sigma_optimizer.update(
+ sigma_grad, sigma_state, params=sigma)
+ new_mu = optax.apply_updates(mu, mu_updates)
+ new_sigma = optax.apply_updates(sigma, sigma_updates)
+ new_sigma = jax.tree_map(lambda x: jnp.clip(x, *sigma_range), new_sigma)
+ new_mu_state.hyperparams['learning_rate'] = old_mu_lr
+ new_sigma_state.hyperparams['learning_rate'] = old_sigma_lr
+
+ # apply projection step and finalize results
+ new_mu, converged = projection(new_mu, policy_hyperparams)
  new_pgpe_params = (new_mu, new_sigma)
  new_pgpe_opt_state = (new_mu_state, new_sigma_state)
  policy_params = new_mu
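Editorial note: `_jax_wrapped_pgpe_kl_term` above is the closed-form KL divergence between the old and new diagonal Gaussian search distributions, summed over parameter dimensions:

```latex
\mathrm{KL}\big(\mathcal{N}(\mu_{\text{old}}, \sigma_{\text{old}}^2)\,\big\|\,\mathcal{N}(\mu, \sigma^2)\big)
\;=\; \sum_i \left[\log\frac{\sigma_i}{\sigma_{\text{old},i}}
 + \frac{\sigma_{\text{old},i}^2 + (\mu_i - \mu_{\text{old},i})^2}{2\,\sigma_i^2}
 - \frac{1}{2}\right],
```

which matches the `0.5 * (2 log(σ/σ_old) + (σ_old/σ)² + ((μ − μ_old)/σ)² − 1)` form in the code. When the total exceeds `max_kl`, both learning rates are scaled by `min(1, sqrt(max_kl / KL))` and the update is recomputed; this is why the optimizers are built with `optax.inject_hyperparams`, which exposes `learning_rate` in the optimizer state, and why the KL constraint is disabled when that injection fails.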
@@ -1462,14 +1546,14 @@ def mean_deviation_utility(returns: jnp.ndarray, beta: float) -> float:
  @jax.jit
  def mean_semideviation_utility(returns: jnp.ndarray, beta: float) -> float:
  mu = jnp.mean(returns)
- msd = jnp.sqrt(jnp.mean(jnp.minimum(0.0, returns - mu) ** 2))
+ msd = jnp.sqrt(jnp.mean(jnp.square(jnp.minimum(0.0, returns - mu))))
  return mu - 0.5 * beta * msd


  @jax.jit
  def mean_semivariance_utility(returns: jnp.ndarray, beta: float) -> float:
  mu = jnp.mean(returns)
- msv = jnp.mean(jnp.minimum(0.0, returns - mu) ** 2)
+ msv = jnp.mean(jnp.square(jnp.minimum(0.0, returns - mu)))
  return mu - 0.5 * beta * msv


@@ -1768,7 +1852,6 @@ r"""

  # optimization
  self.update = self._jax_update(train_loss)
- self.check_zero_grad = self._jax_check_zero_gradients()

  # pgpe option
  if self.use_pgpe:
@@ -1831,6 +1914,12 @@ r"""
  projection = self.plan.projection
  use_ls = self.line_search_kwargs is not None

+ # check if the gradients are all zeros
+ def _jax_wrapped_zero_gradients(grad):
+ leaves, _ = jax.tree_util.tree_flatten(
+ jax.tree_map(lambda g: jnp.allclose(g, 0), grad))
+ return jnp.all(jnp.asarray(leaves))
+
  # calculate the plan gradient w.r.t. return loss and update optimizer
  # also perform a projection step to satisfy constraints on actions
  def _jax_wrapped_loss_swapped(policy_params, key, policy_hyperparams,
@@ -1855,23 +1944,12 @@ r"""
  policy_params, converged = projection(policy_params, policy_hyperparams)
  log['grad'] = grad
  log['updates'] = updates
+ zero_grads = _jax_wrapped_zero_gradients(grad)
  return policy_params, converged, opt_state, opt_aux, \
- loss_val, log, model_params
+ loss_val, log, model_params, zero_grads

  return jax.jit(_jax_wrapped_plan_update)

- def _jax_check_zero_gradients(self):
-
- def _jax_wrapped_zero_gradient(grad):
- return jnp.allclose(grad, 0)
-
- def _jax_wrapped_zero_gradients(grad):
- leaves, _ = jax.tree_util.tree_flatten(
- jax.tree_map(_jax_wrapped_zero_gradient, grad))
- return jnp.all(jnp.asarray(leaves))
-
- return jax.jit(_jax_wrapped_zero_gradients)
-
  def _batched_init_subs(self, subs):
  rddl = self.rddl
  n_train, n_test = self.batch_size_train, self.batch_size_test
@@ -2175,11 +2253,12 @@ r"""
2175
2253
  # ======================================================================
2176
2254
 
2177
2255
  # initialize running statistics
2178
- best_params, best_loss, best_grad = policy_params, jnp.inf, jnp.inf
2256
+ best_params, best_loss, best_grad = policy_params, jnp.inf, None
2179
2257
  last_iter_improve = 0
2180
2258
  rolling_test_loss = RollingMean(test_rolling_window)
2181
2259
  log = {}
2182
2260
  status = JaxPlannerStatus.NORMAL
2261
+ progress_percent = 0
2183
2262
 
2184
2263
  # initialize stopping criterion
2185
2264
  if stopping_rule is not None:
@@ -2191,18 +2270,19 @@ r"""
2191
2270
  dashboard_id, dashboard.get_planner_info(self),
2192
2271
  key=dash_key, viz=self.dashboard_viz)
2193
2272
 
2273
+ # progress bar
2274
+ if print_progress:
2275
+ progress_bar = tqdm(None, total=100, position=tqdm_position,
2276
+ bar_format='{l_bar}{bar}| {elapsed} {postfix}')
2277
+ else:
2278
+ progress_bar = None
2279
+ position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
2280
+
2194
2281
  # ======================================================================
2195
2282
  # MAIN TRAINING LOOP BEGINS
2196
2283
  # ======================================================================
2197
2284
 
2198
- iters = range(epochs)
2199
- if print_progress:
2200
- iters = tqdm(iters, total=100,
2201
- bar_format='{l_bar}{bar}| {elapsed} {postfix}',
2202
- position=tqdm_position)
2203
- position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
2204
-
2205
- for it in iters:
2285
+ for it in range(epochs):
2206
2286
 
2207
2287
  # ==================================================================
2208
2288
  # NEXT GRADIENT DESCENT STEP
@@ -2213,8 +2293,9 @@ r"""
2213
2293
  # update the parameters of the plan
2214
2294
  key, subkey = random.split(key)
2215
2295
  (policy_params, converged, opt_state, opt_aux, train_loss, train_log,
2216
- model_params) = self.update(subkey, policy_params, policy_hyperparams,
2217
- train_subs, model_params, opt_state, opt_aux)
2296
+ model_params, zero_grads) = self.update(
2297
+ subkey, policy_params, policy_hyperparams, train_subs, model_params,
2298
+ opt_state, opt_aux)
2218
2299
  test_loss, (test_log, model_params_test) = self.test_loss(
2219
2300
  subkey, policy_params, policy_hyperparams, test_subs, model_params_test)
2220
2301
  test_loss_smooth = rolling_test_loss.update(test_loss)
@@ -2224,8 +2305,9 @@ r"""
2224
2305
  if self.use_pgpe:
2225
2306
  key, subkey = random.split(key)
2226
2307
  pgpe_params, r_max, pgpe_opt_state, pgpe_param, pgpe_converged = \
2227
- self.pgpe.update(subkey, pgpe_params, r_max, policy_hyperparams,
2228
- test_subs, model_params, pgpe_opt_state)
2308
+ self.pgpe.update(subkey, pgpe_params, r_max, progress_percent,
2309
+ policy_hyperparams, test_subs, model_params_test,
2310
+ pgpe_opt_state)
2229
2311
  pgpe_loss, _ = self.test_loss(
2230
2312
  subkey, pgpe_param, policy_hyperparams, test_subs, model_params_test)
2231
2313
  pgpe_loss_smooth = rolling_pgpe_loss.update(pgpe_loss)
@@ -2252,7 +2334,7 @@ r"""
2252
2334
  # ==================================================================
2253
2335
 
2254
2336
  # no progress
2255
- if (not pgpe_improve) and self.check_zero_grad(train_log['grad']):
2337
+ if (not pgpe_improve) and zero_grads:
2256
2338
  status = JaxPlannerStatus.NO_PROGRESS
2257
2339
 
2258
2340
  # constraint satisfaction problem
@@ -2311,14 +2393,15 @@ r"""
2311
2393
 
2312
2394
  # if the progress bar is used
2313
2395
  if print_progress:
2314
- iters.n = progress_percent
2315
- iters.set_description(
2396
+ progress_bar.set_description(
2316
2397
  f'{position_str} {it:6} it / {-train_loss:14.5f} train / '
2317
2398
  f'{-test_loss_smooth:14.5f} test / {-best_loss:14.5f} best / '
2318
2399
  f'{status.value} status / {total_pgpe_it:6} pgpe',
2319
2400
  refresh=False
2320
2401
  )
2321
- iters.set_postfix_str(f"{(it + 1) / elapsed:.2f}it/s", refresh=True)
2402
+ progress_bar.set_postfix_str(
2403
+ f"{(it + 1) / (elapsed + 1e-6):.2f}it/s", refresh=False)
2404
+ progress_bar.update(progress_percent - progress_bar.n)
2322
2405
 
2323
2406
  # dash-board
2324
2407
  if dashboard is not None:
@@ -2339,7 +2422,7 @@ r"""
2339
2422
 
2340
2423
  # release resources
2341
2424
  if print_progress:
2342
- iters.close()
2425
+ progress_bar.close()
2343
2426
 
2344
2427
  # validate the test return
2345
2428
  if log:
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: pyRDDLGym-jax
- Version: 2.1
+ Version: 2.3
  Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -58,8 +58,11 @@ Dynamic: summary

  Purpose:

- 1. automatic translation of any RDDL description file into a differentiable simulator in JAX
- 2. flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and Bayesian hyper-parameter tuning.
+ 1. automatic translation of RDDL description files into differentiable JAX simulators
+ 2. implementation of (highly configurable) operator relaxations for working in discrete and hybrid domains
+ 3. flexible policy representations and automated Bayesian hyper-parameter tuning
+ 4. interactive dashboard for dynamic visualization and debugging
+ 5. hybridization with parameter-exploring policy gradients.

  Some demos of solved problems by JaxPlan:

@@ -235,8 +238,23 @@ More documentation about this and other new features will be coming soon.

  ## Tuning the Planner

- It is easy to tune the planner's hyper-parameters efficiently and automatically using Bayesian optimization.
- To do this, first create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
+ A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:
+
+ ```shell
+ jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
+ ```
+
+ where:
+ - ``domain`` is the domain identifier as specified in rddlrepository
+ - ``instance`` is the instance identifier
+ - ``method`` is the planning method to use (i.e. drp, slp, replan)
+ - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
+ - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
+ - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, i.e. the total evaluations = ``iters * workers``
+ - ``dashboard`` is whether the optimizations are tracked in the dashboard application.
+
+ It is easy to tune a custom range of the planner's hyper-parameters efficiently.
+ First create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:

  ```ini
  [Model]
@@ -260,7 +278,7 @@ train_on_reset=True

  would allow tuning the sharpness of model relaxations and the learning rate of the optimizer.

- Next, you must link the patterns in the config with concrete hyper-parameter ranges the tuner will understand:
+ Next, you must link the patterns in the config with concrete hyper-parameter ranges the tuner will understand, and run the optimizer:

  ```python
  import pyRDDLGym
@@ -292,22 +310,7 @@ tuning = JaxParameterTuning(env=env,
  gp_iters=iters)
  tuning.tune(key=42, log_file='path/to/log.csv')
  ```
-
- A basic run script is provided to run the automatic hyper-parameter tuning for the most sensitive parameters of JaxPlan:
-
- ```shell
- jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
- ```
-
- where:
- - ``domain`` is the domain identifier as specified in rddlrepository
- - ``instance`` is the instance identifier
- - ``method`` is the planning method to use (i.e. drp, slp, replan)
- - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
- - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
- - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, i.e. the total evaluations = ``iters * workers``
- - ``dashboard`` is whether the optimizations are tracked in the dashboard application.
-
+

  ## Simulation

@@ -19,7 +19,7 @@ long_description = (Path(__file__).parent / "README.md").read_text()

  setup(
  name='pyRDDLGym-jax',
- version='2.1',
+ version='2.3',
  author="Michael Gimelfarb, Ayal Taitler, Scott Sanner",
  author_email="mike.gimelfarb@mail.utoronto.ca, ataitler@gmail.com, ssanner@mie.utoronto.ca",
  description="pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.",
@@ -1 +0,0 @@
- __version__ = '2.1'