pyRDDLGym-jax 0.3__py3-none-any.whl → 0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +90 -67
- pyRDDLGym_jax/core/logic.py +286 -82
- pyRDDLGym_jax/core/planner.py +191 -97
- pyRDDLGym_jax/core/simulator.py +2 -1
- pyRDDLGym_jax/core/tuning.py +58 -63
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +2 -1
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +2 -1
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +2 -1
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +4 -3
- pyRDDLGym_jax/examples/configs/default_replan.cfg +2 -1
- pyRDDLGym_jax/examples/run_tune.py +1 -3
- pyRDDLGym_jax-0.5.dist-info/METADATA +278 -0
- {pyRDDLGym_jax-0.3.dist-info → pyRDDLGym_jax-0.5.dist-info}/RECORD +17 -17
- {pyRDDLGym_jax-0.3.dist-info → pyRDDLGym_jax-0.5.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax-0.3.dist-info/METADATA +0 -26
- {pyRDDLGym_jax-0.3.dist-info → pyRDDLGym_jax-0.5.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-0.3.dist-info → pyRDDLGym_jax-0.5.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/__init__.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = '0.
|
|
1
|
+
__version__ = '0.4'
|
pyRDDLGym_jax/core/compiler.py
CHANGED
|
@@ -1,22 +1,11 @@
|
|
|
1
1
|
from functools import partial
|
|
2
|
+
import traceback
|
|
3
|
+
from typing import Any, Callable, Dict, List, Optional
|
|
4
|
+
|
|
2
5
|
import jax
|
|
3
6
|
import jax.numpy as jnp
|
|
4
7
|
import jax.random as random
|
|
5
8
|
import jax.scipy as scipy
|
|
6
|
-
import traceback
|
|
7
|
-
from typing import Any, Callable, Dict, List, Optional
|
|
8
|
-
|
|
9
|
-
from pyRDDLGym.core.debug.exception import raise_warning
|
|
10
|
-
|
|
11
|
-
# more robust approach - if user does not have this or broken try to continue
|
|
12
|
-
try:
|
|
13
|
-
from tensorflow_probability.substrates import jax as tfp
|
|
14
|
-
except Exception:
|
|
15
|
-
raise_warning('Failed to import tensorflow-probability: '
|
|
16
|
-
'compilation of some complex distributions '
|
|
17
|
-
'(Binomial, Negative-Binomial, Multinomial) will fail.', 'red')
|
|
18
|
-
traceback.print_exc()
|
|
19
|
-
tfp = None
|
|
20
9
|
|
|
21
10
|
from pyRDDLGym.core.compiler.initializer import RDDLValueInitializer
|
|
22
11
|
from pyRDDLGym.core.compiler.levels import RDDLLevelAnalysis
|
|
@@ -25,12 +14,23 @@ from pyRDDLGym.core.compiler.tracer import RDDLObjectsTracer
|
|
|
25
14
|
from pyRDDLGym.core.constraints import RDDLConstraints
|
|
26
15
|
from pyRDDLGym.core.debug.exception import (
|
|
27
16
|
print_stack_trace,
|
|
17
|
+
raise_warning,
|
|
28
18
|
RDDLInvalidNumberOfArgumentsError,
|
|
29
19
|
RDDLNotImplementedError
|
|
30
20
|
)
|
|
31
21
|
from pyRDDLGym.core.debug.logger import Logger
|
|
32
22
|
from pyRDDLGym.core.simulator import RDDLSimulatorPrecompiled
|
|
33
23
|
|
|
24
|
+
# more robust approach - if user does not have this or broken try to continue
|
|
25
|
+
try:
|
|
26
|
+
from tensorflow_probability.substrates import jax as tfp
|
|
27
|
+
except Exception:
|
|
28
|
+
raise_warning('Failed to import tensorflow-probability: '
|
|
29
|
+
'compilation of some complex distributions '
|
|
30
|
+
'(Binomial, Negative-Binomial, Multinomial) will fail.', 'red')
|
|
31
|
+
traceback.print_exc()
|
|
32
|
+
tfp = None
|
|
33
|
+
|
|
34
34
|
|
|
35
35
|
# ===========================================================================
|
|
36
36
|
# EXACT RDDL TO JAX COMPILATION RULES
|
|
@@ -87,7 +87,7 @@ def _function_aggregation_exact_named(op, name):
|
|
|
87
87
|
def _function_if_exact_named():
|
|
88
88
|
|
|
89
89
|
def _jax_wrapped_if_exact(c, a, b, param):
|
|
90
|
-
return jnp.where(c, a, b)
|
|
90
|
+
return jnp.where(c > 0.5, a, b)
|
|
91
91
|
|
|
92
92
|
return _jax_wrapped_if_exact
|
|
93
93
|
|
|
@@ -114,16 +114,27 @@ def _function_bernoulli_exact_named():
|
|
|
114
114
|
def _function_discrete_exact_named():
|
|
115
115
|
|
|
116
116
|
def _jax_wrapped_discrete_exact(key, prob, param):
|
|
117
|
-
|
|
118
|
-
sample = random.categorical(key=key, logits=logits, axis=-1)
|
|
119
|
-
out_of_bounds = jnp.logical_not(jnp.logical_and(
|
|
120
|
-
jnp.all(prob >= 0),
|
|
121
|
-
jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
|
|
122
|
-
return sample, out_of_bounds
|
|
117
|
+
return random.categorical(key=key, logits=jnp.log(prob), axis=-1)
|
|
123
118
|
|
|
124
119
|
return _jax_wrapped_discrete_exact
|
|
125
120
|
|
|
126
121
|
|
|
122
|
+
def _function_poisson_exact_named():
|
|
123
|
+
|
|
124
|
+
def _jax_wrapped_poisson_exact(key, rate, param):
|
|
125
|
+
return random.poisson(key=key, lam=rate, dtype=jnp.int64)
|
|
126
|
+
|
|
127
|
+
return _jax_wrapped_poisson_exact
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def _function_geometric_exact_named():
|
|
131
|
+
|
|
132
|
+
def _jax_wrapped_geometric_exact(key, prob, param):
|
|
133
|
+
return random.geometric(key=key, p=prob, dtype=jnp.int64)
|
|
134
|
+
|
|
135
|
+
return _jax_wrapped_geometric_exact
|
|
136
|
+
|
|
137
|
+
|
|
127
138
|
class JaxRDDLCompiler:
|
|
128
139
|
'''Compiles a RDDL AST representation into an equivalent JAX representation.
|
|
129
140
|
All operations are identical to their numpy equivalents.
|
|
@@ -210,12 +221,12 @@ class JaxRDDLCompiler:
|
|
|
210
221
|
}
|
|
211
222
|
|
|
212
223
|
EXACT_RDDL_TO_JAX_IF = _function_if_exact_named()
|
|
213
|
-
|
|
214
224
|
EXACT_RDDL_TO_JAX_SWITCH = _function_switch_exact_named()
|
|
215
225
|
|
|
216
226
|
EXACT_RDDL_TO_JAX_BERNOULLI = _function_bernoulli_exact_named()
|
|
217
|
-
|
|
218
227
|
EXACT_RDDL_TO_JAX_DISCRETE = _function_discrete_exact_named()
|
|
228
|
+
EXACT_RDDL_TO_JAX_POISSON = _function_poisson_exact_named()
|
|
229
|
+
EXACT_RDDL_TO_JAX_GEOMETRIC = _function_geometric_exact_named()
|
|
219
230
|
|
|
220
231
|
def __init__(self, rddl: RDDLLiftedModel,
|
|
221
232
|
allow_synchronous_state: bool=True,
|
|
@@ -289,6 +300,8 @@ class JaxRDDLCompiler:
|
|
|
289
300
|
self.SWITCH_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
|
|
290
301
|
self.BERNOULLI_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BERNOULLI
|
|
291
302
|
self.DISCRETE_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
|
|
303
|
+
self.POISSON_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_POISSON
|
|
304
|
+
self.GEOMETRIC_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_GEOMETRIC
|
|
292
305
|
|
|
293
306
|
# ===========================================================================
|
|
294
307
|
# main compilation subroutines
|
|
@@ -996,13 +1009,14 @@ class JaxRDDLCompiler:
|
|
|
996
1009
|
jax_op, jax_param = self._unwrap(negative_op, expr.id, info)
|
|
997
1010
|
return self._jax_unary(jax_expr, jax_op, jax_param, at_least_int=True)
|
|
998
1011
|
|
|
999
|
-
elif n == 2:
|
|
1000
|
-
|
|
1001
|
-
jax_lhs = self._jax(lhs, info)
|
|
1002
|
-
jax_rhs = self._jax(rhs, info)
|
|
1012
|
+
elif n == 2 or (n >= 2 and op in {'*', '+'}):
|
|
1013
|
+
jax_exprs = [self._jax(arg, info) for arg in args]
|
|
1003
1014
|
jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
|
|
1004
|
-
|
|
1005
|
-
|
|
1015
|
+
result = jax_exprs[0]
|
|
1016
|
+
for jax_rhs in jax_exprs[1:]:
|
|
1017
|
+
result = self._jax_binary(
|
|
1018
|
+
result, jax_rhs, jax_op, jax_param, at_least_int=True)
|
|
1019
|
+
return result
|
|
1006
1020
|
|
|
1007
1021
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1008
1022
|
|
|
@@ -1046,13 +1060,14 @@ class JaxRDDLCompiler:
|
|
|
1046
1060
|
jax_op, jax_param = self._unwrap(logical_not_op, expr.id, info)
|
|
1047
1061
|
return self._jax_unary(jax_expr, jax_op, jax_param, check_dtype=bool)
|
|
1048
1062
|
|
|
1049
|
-
elif n == 2:
|
|
1050
|
-
|
|
1051
|
-
jax_lhs = self._jax(lhs, info)
|
|
1052
|
-
jax_rhs = self._jax(rhs, info)
|
|
1063
|
+
elif n == 2 or (n >= 2 and op in {'^', '&', '|'}):
|
|
1064
|
+
jax_exprs = [self._jax(arg, info) for arg in args]
|
|
1053
1065
|
jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
|
|
1054
|
-
|
|
1055
|
-
|
|
1066
|
+
result = jax_exprs[0]
|
|
1067
|
+
for jax_rhs in jax_exprs[1:]:
|
|
1068
|
+
result = self._jax_binary(
|
|
1069
|
+
result, jax_rhs, jax_op, jax_param, check_dtype=bool)
|
|
1070
|
+
return result
|
|
1056
1071
|
|
|
1057
1072
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1058
1073
|
|
|
@@ -1165,16 +1180,17 @@ class JaxRDDLCompiler:
|
|
|
1165
1180
|
return _jax_wrapped_if_then_else
|
|
1166
1181
|
|
|
1167
1182
|
def _jax_switch(self, expr, info):
|
|
1183
|
+
pred, *_ = expr.args
|
|
1168
1184
|
|
|
1169
|
-
# if
|
|
1170
|
-
|
|
1185
|
+
# if predicate is non-fluent, always use the exact operation
|
|
1186
|
+
# case conditions are currently only literals so they are non-fluent
|
|
1187
|
+
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(pred):
|
|
1171
1188
|
switch_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
|
|
1172
1189
|
else:
|
|
1173
1190
|
switch_op = self.SWITCH_HELPER
|
|
1174
1191
|
jax_switch, jax_param = self._unwrap(switch_op, expr.id, info)
|
|
1175
1192
|
|
|
1176
1193
|
# recursively compile predicate
|
|
1177
|
-
pred, *_ = expr.args
|
|
1178
1194
|
jax_pred = self._jax(pred, info)
|
|
1179
1195
|
|
|
1180
1196
|
# recursively compile cases
|
|
@@ -1426,15 +1442,24 @@ class JaxRDDLCompiler:
|
|
|
1426
1442
|
def _jax_poisson(self, expr, info):
|
|
1427
1443
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_POISSON']
|
|
1428
1444
|
JaxRDDLCompiler._check_num_args(expr, 1)
|
|
1429
|
-
|
|
1430
1445
|
arg_rate, = expr.args
|
|
1446
|
+
|
|
1447
|
+
# if rate is non-fluent, always use the exact operation
|
|
1448
|
+
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_rate):
|
|
1449
|
+
poisson_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_POISSON
|
|
1450
|
+
else:
|
|
1451
|
+
poisson_op = self.POISSON_HELPER
|
|
1452
|
+
jax_poisson, jax_param = self._unwrap(poisson_op, expr.id, info)
|
|
1453
|
+
|
|
1454
|
+
# recursively compile arguments
|
|
1431
1455
|
jax_rate = self._jax(arg_rate, info)
|
|
1432
1456
|
|
|
1433
1457
|
# uses the implicit JAX subroutine
|
|
1434
1458
|
def _jax_wrapped_distribution_poisson(x, params, key):
|
|
1435
1459
|
rate, key, err = jax_rate(x, params, key)
|
|
1436
1460
|
key, subkey = random.split(key)
|
|
1437
|
-
|
|
1461
|
+
param = params.get(jax_param, None)
|
|
1462
|
+
sample = jax_poisson(subkey, rate, param).astype(self.INT)
|
|
1438
1463
|
out_of_bounds = jnp.logical_not(jnp.all(rate >= 0))
|
|
1439
1464
|
err |= (out_of_bounds * ERR)
|
|
1440
1465
|
return sample, key, err
|
|
@@ -1535,33 +1560,25 @@ class JaxRDDLCompiler:
|
|
|
1535
1560
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_GEOMETRIC']
|
|
1536
1561
|
JaxRDDLCompiler._check_num_args(expr, 1)
|
|
1537
1562
|
arg_prob, = expr.args
|
|
1563
|
+
|
|
1564
|
+
# if prob is non-fluent, always use the exact operation
|
|
1565
|
+
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_prob):
|
|
1566
|
+
geom_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_GEOMETRIC
|
|
1567
|
+
else:
|
|
1568
|
+
geom_op = self.GEOMETRIC_HELPER
|
|
1569
|
+
jax_geom, jax_param = self._unwrap(geom_op, expr.id, info)
|
|
1570
|
+
|
|
1571
|
+
# recursively compile arguments
|
|
1538
1572
|
jax_prob = self._jax(arg_prob, info)
|
|
1539
1573
|
|
|
1540
|
-
|
|
1541
|
-
|
|
1542
|
-
|
|
1543
|
-
|
|
1544
|
-
|
|
1545
|
-
|
|
1546
|
-
|
|
1547
|
-
|
|
1548
|
-
err |= (out_of_bounds * ERR)
|
|
1549
|
-
return sample, key, err
|
|
1550
|
-
|
|
1551
|
-
else:
|
|
1552
|
-
floor_op, jax_param = self._unwrap(
|
|
1553
|
-
self.KNOWN_UNARY['floor'], expr.id, info)
|
|
1554
|
-
|
|
1555
|
-
# reparameterization trick Geom(p) = floor(ln(U(0, 1)) / ln(p)) + 1
|
|
1556
|
-
def _jax_wrapped_distribution_geometric(x, params, key):
|
|
1557
|
-
prob, key, err = jax_prob(x, params, key)
|
|
1558
|
-
key, subkey = random.split(key)
|
|
1559
|
-
U = random.uniform(key=subkey, shape=jnp.shape(prob), dtype=self.REAL)
|
|
1560
|
-
param = params.get(jax_param, None)
|
|
1561
|
-
sample = floor_op(jnp.log(U) / jnp.log(1.0 - prob), param) + 1
|
|
1562
|
-
out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
|
|
1563
|
-
err |= (out_of_bounds * ERR)
|
|
1564
|
-
return sample, key, err
|
|
1574
|
+
def _jax_wrapped_distribution_geometric(x, params, key):
|
|
1575
|
+
prob, key, err = jax_prob(x, params, key)
|
|
1576
|
+
key, subkey = random.split(key)
|
|
1577
|
+
param = params.get(jax_param, None)
|
|
1578
|
+
sample = jax_geom(subkey, prob, param).astype(self.INT)
|
|
1579
|
+
out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
|
|
1580
|
+
err |= (out_of_bounds * ERR)
|
|
1581
|
+
return sample, key, err
|
|
1565
1582
|
|
|
1566
1583
|
return _jax_wrapped_distribution_geometric
|
|
1567
1584
|
|
|
@@ -1770,7 +1787,10 @@ class JaxRDDLCompiler:
|
|
|
1770
1787
|
# dispatch to sampling subroutine
|
|
1771
1788
|
key, subkey = random.split(key)
|
|
1772
1789
|
param = params.get(jax_param, None)
|
|
1773
|
-
sample, out_of_bounds = jax_discrete(subkey, prob, param)
|
|
1790
|
+
sample = jax_discrete(subkey, prob, param)
|
|
1791
|
+
out_of_bounds = jnp.logical_not(jnp.logical_and(
|
|
1792
|
+
jnp.all(prob >= 0),
|
|
1793
|
+
jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
|
|
1774
1794
|
error |= (out_of_bounds * ERR)
|
|
1775
1795
|
return sample, key, error
|
|
1776
1796
|
|
|
@@ -1803,7 +1823,10 @@ class JaxRDDLCompiler:
|
|
|
1803
1823
|
# dispatch to sampling subroutine
|
|
1804
1824
|
key, subkey = random.split(key)
|
|
1805
1825
|
param = params.get(jax_param, None)
|
|
1806
|
-
sample, out_of_bounds = jax_discrete(subkey, prob, param)
|
|
1826
|
+
sample = jax_discrete(subkey, prob, param)
|
|
1827
|
+
out_of_bounds = jnp.logical_not(jnp.logical_and(
|
|
1828
|
+
jnp.all(prob >= 0),
|
|
1829
|
+
jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
|
|
1807
1830
|
error |= (out_of_bounds * ERR)
|
|
1808
1831
|
return sample, key, error
|
|
1809
1832
|
|