pyRDDLGym-jax 0.3__py3-none-any.whl → 0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pyRDDLGym_jax/__init__.py CHANGED
@@ -1 +1 @@
1
- __version__ = '0.3'
1
+ __version__ = '0.4'
@@ -1,22 +1,11 @@
1
1
  from functools import partial
2
+ import traceback
3
+ from typing import Any, Callable, Dict, List, Optional
4
+
2
5
  import jax
3
6
  import jax.numpy as jnp
4
7
  import jax.random as random
5
8
  import jax.scipy as scipy
6
- import traceback
7
- from typing import Any, Callable, Dict, List, Optional
8
-
9
- from pyRDDLGym.core.debug.exception import raise_warning
10
-
11
- # more robust approach - if user does not have this or broken try to continue
12
- try:
13
- from tensorflow_probability.substrates import jax as tfp
14
- except Exception:
15
- raise_warning('Failed to import tensorflow-probability: '
16
- 'compilation of some complex distributions '
17
- '(Binomial, Negative-Binomial, Multinomial) will fail.', 'red')
18
- traceback.print_exc()
19
- tfp = None
20
9
 
21
10
  from pyRDDLGym.core.compiler.initializer import RDDLValueInitializer
22
11
  from pyRDDLGym.core.compiler.levels import RDDLLevelAnalysis
@@ -25,12 +14,23 @@ from pyRDDLGym.core.compiler.tracer import RDDLObjectsTracer
25
14
  from pyRDDLGym.core.constraints import RDDLConstraints
26
15
  from pyRDDLGym.core.debug.exception import (
27
16
  print_stack_trace,
17
+ raise_warning,
28
18
  RDDLInvalidNumberOfArgumentsError,
29
19
  RDDLNotImplementedError
30
20
  )
31
21
  from pyRDDLGym.core.debug.logger import Logger
32
22
  from pyRDDLGym.core.simulator import RDDLSimulatorPrecompiled
33
23
 
24
+ # more robust approach - if user does not have this or broken try to continue
25
+ try:
26
+ from tensorflow_probability.substrates import jax as tfp
27
+ except Exception:
28
+ raise_warning('Failed to import tensorflow-probability: '
29
+ 'compilation of some complex distributions '
30
+ '(Binomial, Negative-Binomial, Multinomial) will fail.', 'red')
31
+ traceback.print_exc()
32
+ tfp = None
33
+
34
34
 
35
35
  # ===========================================================================
36
36
  # EXACT RDDL TO JAX COMPILATION RULES
@@ -87,7 +87,7 @@ def _function_aggregation_exact_named(op, name):
87
87
  def _function_if_exact_named():
88
88
 
89
89
  def _jax_wrapped_if_exact(c, a, b, param):
90
- return jnp.where(c, a, b)
90
+ return jnp.where(c > 0.5, a, b)
91
91
 
92
92
  return _jax_wrapped_if_exact
93
93
 
@@ -114,16 +114,27 @@ def _function_bernoulli_exact_named():
114
114
  def _function_discrete_exact_named():
115
115
 
116
116
  def _jax_wrapped_discrete_exact(key, prob, param):
117
- logits = jnp.log(prob)
118
- sample = random.categorical(key=key, logits=logits, axis=-1)
119
- out_of_bounds = jnp.logical_not(jnp.logical_and(
120
- jnp.all(prob >= 0),
121
- jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
122
- return sample, out_of_bounds
117
+ return random.categorical(key=key, logits=jnp.log(prob), axis=-1)
123
118
 
124
119
  return _jax_wrapped_discrete_exact
125
120
 
126
121
 
122
+ def _function_poisson_exact_named():
123
+
124
+ def _jax_wrapped_poisson_exact(key, rate, param):
125
+ return random.poisson(key=key, lam=rate, dtype=jnp.int64)
126
+
127
+ return _jax_wrapped_poisson_exact
128
+
129
+
130
+ def _function_geometric_exact_named():
131
+
132
+ def _jax_wrapped_geometric_exact(key, prob, param):
133
+ return random.geometric(key=key, p=prob, dtype=jnp.int64)
134
+
135
+ return _jax_wrapped_geometric_exact
136
+
137
+
127
138
  class JaxRDDLCompiler:
128
139
  '''Compiles a RDDL AST representation into an equivalent JAX representation.
129
140
  All operations are identical to their numpy equivalents.
@@ -210,12 +221,12 @@ class JaxRDDLCompiler:
210
221
  }
211
222
 
212
223
  EXACT_RDDL_TO_JAX_IF = _function_if_exact_named()
213
-
214
224
  EXACT_RDDL_TO_JAX_SWITCH = _function_switch_exact_named()
215
225
 
216
226
  EXACT_RDDL_TO_JAX_BERNOULLI = _function_bernoulli_exact_named()
217
-
218
227
  EXACT_RDDL_TO_JAX_DISCRETE = _function_discrete_exact_named()
228
+ EXACT_RDDL_TO_JAX_POISSON = _function_poisson_exact_named()
229
+ EXACT_RDDL_TO_JAX_GEOMETRIC = _function_geometric_exact_named()
219
230
 
220
231
  def __init__(self, rddl: RDDLLiftedModel,
221
232
  allow_synchronous_state: bool=True,
@@ -289,6 +300,8 @@ class JaxRDDLCompiler:
289
300
  self.SWITCH_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
290
301
  self.BERNOULLI_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BERNOULLI
291
302
  self.DISCRETE_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
303
+ self.POISSON_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_POISSON
304
+ self.GEOMETRIC_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_GEOMETRIC
292
305
 
293
306
  # ===========================================================================
294
307
  # main compilation subroutines
@@ -996,13 +1009,14 @@ class JaxRDDLCompiler:
996
1009
  jax_op, jax_param = self._unwrap(negative_op, expr.id, info)
997
1010
  return self._jax_unary(jax_expr, jax_op, jax_param, at_least_int=True)
998
1011
 
999
- elif n == 2:
1000
- lhs, rhs = args
1001
- jax_lhs = self._jax(lhs, info)
1002
- jax_rhs = self._jax(rhs, info)
1012
+ elif n == 2 or (n >= 2 and op in {'*', '+'}):
1013
+ jax_exprs = [self._jax(arg, info) for arg in args]
1003
1014
  jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
1004
- return self._jax_binary(
1005
- jax_lhs, jax_rhs, jax_op, jax_param, at_least_int=True)
1015
+ result = jax_exprs[0]
1016
+ for jax_rhs in jax_exprs[1:]:
1017
+ result = self._jax_binary(
1018
+ result, jax_rhs, jax_op, jax_param, at_least_int=True)
1019
+ return result
1006
1020
 
1007
1021
  JaxRDDLCompiler._check_num_args(expr, 2)
1008
1022
 
@@ -1046,13 +1060,14 @@ class JaxRDDLCompiler:
1046
1060
  jax_op, jax_param = self._unwrap(logical_not_op, expr.id, info)
1047
1061
  return self._jax_unary(jax_expr, jax_op, jax_param, check_dtype=bool)
1048
1062
 
1049
- elif n == 2:
1050
- lhs, rhs = args
1051
- jax_lhs = self._jax(lhs, info)
1052
- jax_rhs = self._jax(rhs, info)
1063
+ elif n == 2 or (n >= 2 and op in {'^', '&', '|'}):
1064
+ jax_exprs = [self._jax(arg, info) for arg in args]
1053
1065
  jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
1054
- return self._jax_binary(
1055
- jax_lhs, jax_rhs, jax_op, jax_param, check_dtype=bool)
1066
+ result = jax_exprs[0]
1067
+ for jax_rhs in jax_exprs[1:]:
1068
+ result = self._jax_binary(
1069
+ result, jax_rhs, jax_op, jax_param, check_dtype=bool)
1070
+ return result
1056
1071
 
1057
1072
  JaxRDDLCompiler._check_num_args(expr, 2)
1058
1073
 
@@ -1165,16 +1180,17 @@ class JaxRDDLCompiler:
1165
1180
  return _jax_wrapped_if_then_else
1166
1181
 
1167
1182
  def _jax_switch(self, expr, info):
1183
+ pred, *_ = expr.args
1168
1184
 
1169
- # if expression is non-fluent, always use the exact operation
1170
- if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
1185
+ # if predicate is non-fluent, always use the exact operation
1186
+ # case conditions are currently only literals so they are non-fluent
1187
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(pred):
1171
1188
  switch_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
1172
1189
  else:
1173
1190
  switch_op = self.SWITCH_HELPER
1174
1191
  jax_switch, jax_param = self._unwrap(switch_op, expr.id, info)
1175
1192
 
1176
1193
  # recursively compile predicate
1177
- pred, *_ = expr.args
1178
1194
  jax_pred = self._jax(pred, info)
1179
1195
 
1180
1196
  # recursively compile cases
@@ -1426,15 +1442,24 @@ class JaxRDDLCompiler:
1426
1442
  def _jax_poisson(self, expr, info):
1427
1443
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_POISSON']
1428
1444
  JaxRDDLCompiler._check_num_args(expr, 1)
1429
-
1430
1445
  arg_rate, = expr.args
1446
+
1447
+ # if rate is non-fluent, always use the exact operation
1448
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_rate):
1449
+ poisson_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_POISSON
1450
+ else:
1451
+ poisson_op = self.POISSON_HELPER
1452
+ jax_poisson, jax_param = self._unwrap(poisson_op, expr.id, info)
1453
+
1454
+ # recursively compile arguments
1431
1455
  jax_rate = self._jax(arg_rate, info)
1432
1456
 
1433
1457
  # uses the implicit JAX subroutine
1434
1458
  def _jax_wrapped_distribution_poisson(x, params, key):
1435
1459
  rate, key, err = jax_rate(x, params, key)
1436
1460
  key, subkey = random.split(key)
1437
- sample = random.poisson(key=subkey, lam=rate, dtype=self.INT)
1461
+ param = params.get(jax_param, None)
1462
+ sample = jax_poisson(subkey, rate, param).astype(self.INT)
1438
1463
  out_of_bounds = jnp.logical_not(jnp.all(rate >= 0))
1439
1464
  err |= (out_of_bounds * ERR)
1440
1465
  return sample, key, err
@@ -1535,33 +1560,25 @@ class JaxRDDLCompiler:
1535
1560
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_GEOMETRIC']
1536
1561
  JaxRDDLCompiler._check_num_args(expr, 1)
1537
1562
  arg_prob, = expr.args
1563
+
1564
+ # if prob is non-fluent, always use the exact operation
1565
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_prob):
1566
+ geom_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_GEOMETRIC
1567
+ else:
1568
+ geom_op = self.GEOMETRIC_HELPER
1569
+ jax_geom, jax_param = self._unwrap(geom_op, expr.id, info)
1570
+
1571
+ # recursively compile arguments
1538
1572
  jax_prob = self._jax(arg_prob, info)
1539
1573
 
1540
- if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_prob):
1541
-
1542
- # prob is non-fluent: do not reparameterize
1543
- def _jax_wrapped_distribution_geometric(x, params, key):
1544
- prob, key, err = jax_prob(x, params, key)
1545
- key, subkey = random.split(key)
1546
- sample = random.geometric(key=subkey, p=prob, dtype=self.INT)
1547
- out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
1548
- err |= (out_of_bounds * ERR)
1549
- return sample, key, err
1550
-
1551
- else:
1552
- floor_op, jax_param = self._unwrap(
1553
- self.KNOWN_UNARY['floor'], expr.id, info)
1554
-
1555
- # reparameterization trick Geom(p) = floor(ln(U(0, 1)) / ln(p)) + 1
1556
- def _jax_wrapped_distribution_geometric(x, params, key):
1557
- prob, key, err = jax_prob(x, params, key)
1558
- key, subkey = random.split(key)
1559
- U = random.uniform(key=subkey, shape=jnp.shape(prob), dtype=self.REAL)
1560
- param = params.get(jax_param, None)
1561
- sample = floor_op(jnp.log(U) / jnp.log(1.0 - prob), param) + 1
1562
- out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
1563
- err |= (out_of_bounds * ERR)
1564
- return sample, key, err
1574
+ def _jax_wrapped_distribution_geometric(x, params, key):
1575
+ prob, key, err = jax_prob(x, params, key)
1576
+ key, subkey = random.split(key)
1577
+ param = params.get(jax_param, None)
1578
+ sample = jax_geom(subkey, prob, param).astype(self.INT)
1579
+ out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
1580
+ err |= (out_of_bounds * ERR)
1581
+ return sample, key, err
1565
1582
 
1566
1583
  return _jax_wrapped_distribution_geometric
1567
1584
 
@@ -1770,7 +1787,10 @@ class JaxRDDLCompiler:
1770
1787
  # dispatch to sampling subroutine
1771
1788
  key, subkey = random.split(key)
1772
1789
  param = params.get(jax_param, None)
1773
- sample, out_of_bounds = jax_discrete(subkey, prob, param)
1790
+ sample = jax_discrete(subkey, prob, param)
1791
+ out_of_bounds = jnp.logical_not(jnp.logical_and(
1792
+ jnp.all(prob >= 0),
1793
+ jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
1774
1794
  error |= (out_of_bounds * ERR)
1775
1795
  return sample, key, error
1776
1796
 
@@ -1803,7 +1823,10 @@ class JaxRDDLCompiler:
1803
1823
  # dispatch to sampling subroutine
1804
1824
  key, subkey = random.split(key)
1805
1825
  param = params.get(jax_param, None)
1806
- sample, out_of_bounds = jax_discrete(subkey, prob, param)
1826
+ sample = jax_discrete(subkey, prob, param)
1827
+ out_of_bounds = jnp.logical_not(jnp.logical_and(
1828
+ jnp.all(prob >= 0),
1829
+ jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
1807
1830
  error |= (out_of_bounds * ERR)
1808
1831
  return sample, key, error
1809
1832