pyRDDLGym-jax 2.2__py3-none-any.whl → 2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +16 -11
- pyRDDLGym_jax/core/logic.py +233 -119
- pyRDDLGym_jax/core/planner.py +489 -218
- pyRDDLGym_jax/core/tuning.py +28 -22
- pyRDDLGym_jax/examples/run_plan.py +2 -2
- pyRDDLGym_jax/examples/run_scipy.py +2 -2
- {pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.4.dist-info}/METADATA +1 -1
- {pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.4.dist-info}/RECORD +13 -13
- {pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.4.dist-info}/WHEEL +1 -1
- {pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.4.dist-info}/LICENSE +0 -0
- {pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.4.dist-info}/entry_points.txt +0 -0
- {pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.4.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/__init__.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = '2.2'
|
|
1
|
+
__version__ = '2.4'
|
pyRDDLGym_jax/core/compiler.py
CHANGED
|
@@ -471,8 +471,7 @@ class JaxRDDLCompiler:
|
|
|
471
471
|
return printed
|
|
472
472
|
|
|
473
473
|
def model_parameter_info(self) -> Dict[str, Dict[str, Any]]:
|
|
474
|
-
'''Returns a dictionary of additional information about model
|
|
475
|
-
parameters.'''
|
|
474
|
+
'''Returns a dictionary of additional information about model parameters.'''
|
|
476
475
|
result = {}
|
|
477
476
|
for (id, value) in self.model_params.items():
|
|
478
477
|
expr_id = int(str(id).split('_')[0])
|
|
@@ -799,7 +798,7 @@ class JaxRDDLCompiler:
|
|
|
799
798
|
elif n == 2 or (n >= 2 and op in {'*', '+'}):
|
|
800
799
|
jax_exprs = [self._jax(arg, init_params) for arg in args]
|
|
801
800
|
result = jax_exprs[0]
|
|
802
|
-
for i, jax_rhs in enumerate(jax_exprs[1:]):
|
|
801
|
+
for (i, jax_rhs) in enumerate(jax_exprs[1:]):
|
|
803
802
|
jax_op = valid_ops[op](f'{expr.id}_{op}{i}', init_params)
|
|
804
803
|
result = self._jax_binary(result, jax_rhs, jax_op, at_least_int=True)
|
|
805
804
|
return result
|
|
@@ -1019,6 +1018,9 @@ class JaxRDDLCompiler:
|
|
|
1019
1018
|
# UnnormDiscrete: complete (subclass uses Gumbel-softmax)
|
|
1020
1019
|
# Discrete(p): complete (subclass uses Gumbel-softmax)
|
|
1021
1020
|
# UnnormDiscrete(p): complete (subclass uses Gumbel-softmax)
|
|
1021
|
+
# Poisson (subclass uses Gumbel-softmax or Poisson process trick)
|
|
1022
|
+
# Binomial (subclass uses Gumbel-softmax or Normal approximation)
|
|
1023
|
+
# NegativeBinomial (subclass uses Poisson-Gamma mixture)
|
|
1022
1024
|
|
|
1023
1025
|
# distributions which seem to support backpropagation (need more testing):
|
|
1024
1026
|
# Beta
|
|
@@ -1026,11 +1028,8 @@ class JaxRDDLCompiler:
|
|
|
1026
1028
|
# Gamma
|
|
1027
1029
|
# ChiSquare
|
|
1028
1030
|
# Dirichlet
|
|
1029
|
-
# Poisson (subclass uses Gumbel-softmax or Poisson process trick)
|
|
1030
1031
|
|
|
1031
1032
|
# distributions with incomplete reparameterization support (TODO):
|
|
1032
|
-
# Binomial
|
|
1033
|
-
# NegativeBinomial
|
|
1034
1033
|
# Multinomial
|
|
1035
1034
|
|
|
1036
1035
|
def _jax_random(self, expr, init_params):
|
|
@@ -1299,8 +1298,17 @@ class JaxRDDLCompiler:
|
|
|
1299
1298
|
def _jax_negative_binomial(self, expr, init_params):
|
|
1300
1299
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_NEGATIVE_BINOMIAL']
|
|
1301
1300
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1302
|
-
|
|
1303
1301
|
arg_trials, arg_prob = expr.args
|
|
1302
|
+
|
|
1303
|
+
# if prob is non-fluent, always use the exact operation
|
|
1304
|
+
if self.compile_non_fluent_exact \
|
|
1305
|
+
and not self.traced.cached_is_fluent(arg_trials) \
|
|
1306
|
+
and not self.traced.cached_is_fluent(arg_prob):
|
|
1307
|
+
negbin_op = self.EXACT_OPS['sampling']['NegativeBinomial']
|
|
1308
|
+
else:
|
|
1309
|
+
negbin_op = self.OPS['sampling']['NegativeBinomial']
|
|
1310
|
+
jax_op = negbin_op(expr.id, init_params)
|
|
1311
|
+
|
|
1304
1312
|
jax_trials = self._jax(arg_trials, init_params)
|
|
1305
1313
|
jax_prob = self._jax(arg_prob, init_params)
|
|
1306
1314
|
|
|
@@ -1308,11 +1316,8 @@ class JaxRDDLCompiler:
|
|
|
1308
1316
|
def _jax_wrapped_distribution_negative_binomial(x, params, key):
|
|
1309
1317
|
trials, key, err2, params = jax_trials(x, params, key)
|
|
1310
1318
|
prob, key, err1, params = jax_prob(x, params, key)
|
|
1311
|
-
trials = jnp.asarray(trials, dtype=self.REAL)
|
|
1312
|
-
prob = jnp.asarray(prob, dtype=self.REAL)
|
|
1313
1319
|
key, subkey = random.split(key)
|
|
1314
|
-
|
|
1315
|
-
sample = jnp.asarray(dist.sample(seed=subkey), dtype=self.INT)
|
|
1320
|
+
sample, params = jax_op(subkey, trials, prob, params)
|
|
1316
1321
|
out_of_bounds = jnp.logical_not(jnp.all(
|
|
1317
1322
|
(prob >= 0) & (prob <= 1) & (trials > 0)))
|
|
1318
1323
|
err = err1 | err2 | (out_of_bounds * ERR)
|