pyRDDLGym-jax 2.2__tar.gz → 2.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/PKG-INFO +1 -1
  2. pyrddlgym_jax-2.4/pyRDDLGym_jax/__init__.py +1 -0
  3. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/core/compiler.py +16 -11
  4. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/core/logic.py +233 -119
  5. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/core/planner.py +489 -218
  6. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/core/tuning.py +28 -22
  7. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/run_plan.py +2 -2
  8. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/run_scipy.py +2 -2
  9. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax.egg-info/PKG-INFO +1 -1
  10. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/setup.py +1 -1
  11. pyrddlgym_jax-2.2/pyRDDLGym_jax/__init__.py +0 -1
  12. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/LICENSE +0 -0
  13. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/README.md +0 -0
  14. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/core/__init__.py +0 -0
  15. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/core/assets/__init__.py +0 -0
  16. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/core/assets/favicon.ico +0 -0
  17. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/core/simulator.py +0 -0
  18. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/core/visualization.py +0 -0
  19. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/entry_point.py +0 -0
  20. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/__init__.py +0 -0
  21. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +0 -0
  22. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +0 -0
  23. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +0 -0
  24. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +0 -0
  25. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +0 -0
  26. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +0 -0
  27. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +0 -0
  28. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +0 -0
  29. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +0 -0
  30. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +0 -0
  31. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +0 -0
  32. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +0 -0
  33. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +0 -0
  34. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +0 -0
  35. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +0 -0
  36. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +0 -0
  37. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +0 -0
  38. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +0 -0
  39. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +0 -0
  40. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
  41. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/configs/default_drp.cfg +0 -0
  42. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/configs/default_replan.cfg +0 -0
  43. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/configs/default_slp.cfg +0 -0
  44. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/configs/tuning_drp.cfg +0 -0
  45. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/configs/tuning_replan.cfg +0 -0
  46. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/configs/tuning_slp.cfg +0 -0
  47. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/run_gradient.py +0 -0
  48. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/run_gym.py +0 -0
  49. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax/examples/run_tune.py +0 -0
  50. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax.egg-info/SOURCES.txt +0 -0
  51. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax.egg-info/dependency_links.txt +0 -0
  52. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax.egg-info/entry_points.txt +0 -0
  53. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax.egg-info/requires.txt +0 -0
  54. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/pyRDDLGym_jax.egg-info/top_level.txt +0 -0
  55. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.4}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: pyRDDLGym-jax
3
- Version: 2.2
3
+ Version: 2.4
4
4
  Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
5
5
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
6
6
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -0,0 +1 @@
1
+ __version__ = '2.4'
@@ -471,8 +471,7 @@ class JaxRDDLCompiler:
471
471
  return printed
472
472
 
473
473
  def model_parameter_info(self) -> Dict[str, Dict[str, Any]]:
474
- '''Returns a dictionary of additional information about model
475
- parameters.'''
474
+ '''Returns a dictionary of additional information about model parameters.'''
476
475
  result = {}
477
476
  for (id, value) in self.model_params.items():
478
477
  expr_id = int(str(id).split('_')[0])
@@ -799,7 +798,7 @@ class JaxRDDLCompiler:
799
798
  elif n == 2 or (n >= 2 and op in {'*', '+'}):
800
799
  jax_exprs = [self._jax(arg, init_params) for arg in args]
801
800
  result = jax_exprs[0]
802
- for i, jax_rhs in enumerate(jax_exprs[1:]):
801
+ for (i, jax_rhs) in enumerate(jax_exprs[1:]):
803
802
  jax_op = valid_ops[op](f'{expr.id}_{op}{i}', init_params)
804
803
  result = self._jax_binary(result, jax_rhs, jax_op, at_least_int=True)
805
804
  return result
@@ -1019,6 +1018,9 @@ class JaxRDDLCompiler:
1019
1018
  # UnnormDiscrete: complete (subclass uses Gumbel-softmax)
1020
1019
  # Discrete(p): complete (subclass uses Gumbel-softmax)
1021
1020
  # UnnormDiscrete(p): complete (subclass uses Gumbel-softmax)
1021
+ # Poisson (subclass uses Gumbel-softmax or Poisson process trick)
1022
+ # Binomial (subclass uses Gumbel-softmax or Normal approximation)
1023
+ # NegativeBinomial (subclass uses Poisson-Gamma mixture)
1022
1024
 
1023
1025
  # distributions which seem to support backpropagation (need more testing):
1024
1026
  # Beta
@@ -1026,11 +1028,8 @@ class JaxRDDLCompiler:
1026
1028
  # Gamma
1027
1029
  # ChiSquare
1028
1030
  # Dirichlet
1029
- # Poisson (subclass uses Gumbel-softmax or Poisson process trick)
1030
1031
 
1031
1032
  # distributions with incomplete reparameterization support (TODO):
1032
- # Binomial
1033
- # NegativeBinomial
1034
1033
  # Multinomial
1035
1034
 
1036
1035
  def _jax_random(self, expr, init_params):
@@ -1299,8 +1298,17 @@ class JaxRDDLCompiler:
1299
1298
  def _jax_negative_binomial(self, expr, init_params):
1300
1299
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_NEGATIVE_BINOMIAL']
1301
1300
  JaxRDDLCompiler._check_num_args(expr, 2)
1302
-
1303
1301
  arg_trials, arg_prob = expr.args
1302
+
1303
+ # if prob is non-fluent, always use the exact operation
1304
+ if self.compile_non_fluent_exact \
1305
+ and not self.traced.cached_is_fluent(arg_trials) \
1306
+ and not self.traced.cached_is_fluent(arg_prob):
1307
+ negbin_op = self.EXACT_OPS['sampling']['NegativeBinomial']
1308
+ else:
1309
+ negbin_op = self.OPS['sampling']['NegativeBinomial']
1310
+ jax_op = negbin_op(expr.id, init_params)
1311
+
1304
1312
  jax_trials = self._jax(arg_trials, init_params)
1305
1313
  jax_prob = self._jax(arg_prob, init_params)
1306
1314
 
@@ -1308,11 +1316,8 @@ class JaxRDDLCompiler:
1308
1316
  def _jax_wrapped_distribution_negative_binomial(x, params, key):
1309
1317
  trials, key, err2, params = jax_trials(x, params, key)
1310
1318
  prob, key, err1, params = jax_prob(x, params, key)
1311
- trials = jnp.asarray(trials, dtype=self.REAL)
1312
- prob = jnp.asarray(prob, dtype=self.REAL)
1313
1319
  key, subkey = random.split(key)
1314
- dist = tfp.distributions.NegativeBinomial(total_count=trials, probs=prob)
1315
- sample = jnp.asarray(dist.sample(seed=subkey), dtype=self.INT)
1320
+ sample, params = jax_op(subkey, trials, prob, params)
1316
1321
  out_of_bounds = jnp.logical_not(jnp.all(
1317
1322
  (prob >= 0) & (prob <= 1) & (trials > 0)))
1318
1323
  err = err1 | err2 | (out_of_bounds * ERR)