pyRDDLGym-jax 0.3-py3-none-any.whl → 0.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pyRDDLGym_jax/__init__.py CHANGED
@@ -1 +1 @@
1
- __version__ = '0.3'
1
+ __version__ = '0.4'
@@ -1,22 +1,11 @@
1
1
  from functools import partial
2
+ import traceback
3
+ from typing import Any, Callable, Dict, List, Optional
4
+
2
5
  import jax
3
6
  import jax.numpy as jnp
4
7
  import jax.random as random
5
8
  import jax.scipy as scipy
6
- import traceback
7
- from typing import Any, Callable, Dict, List, Optional
8
-
9
- from pyRDDLGym.core.debug.exception import raise_warning
10
-
11
- # more robust approach - if user does not have this or broken try to continue
12
- try:
13
- from tensorflow_probability.substrates import jax as tfp
14
- except Exception:
15
- raise_warning('Failed to import tensorflow-probability: '
16
- 'compilation of some complex distributions '
17
- '(Binomial, Negative-Binomial, Multinomial) will fail.', 'red')
18
- traceback.print_exc()
19
- tfp = None
20
9
 
21
10
  from pyRDDLGym.core.compiler.initializer import RDDLValueInitializer
22
11
  from pyRDDLGym.core.compiler.levels import RDDLLevelAnalysis
@@ -25,12 +14,23 @@ from pyRDDLGym.core.compiler.tracer import RDDLObjectsTracer
25
14
  from pyRDDLGym.core.constraints import RDDLConstraints
26
15
  from pyRDDLGym.core.debug.exception import (
27
16
  print_stack_trace,
17
+ raise_warning,
28
18
  RDDLInvalidNumberOfArgumentsError,
29
19
  RDDLNotImplementedError
30
20
  )
31
21
  from pyRDDLGym.core.debug.logger import Logger
32
22
  from pyRDDLGym.core.simulator import RDDLSimulatorPrecompiled
33
23
 
24
+ # more robust approach - if user does not have this or broken try to continue
25
+ try:
26
+ from tensorflow_probability.substrates import jax as tfp
27
+ except Exception:
28
+ raise_warning('Failed to import tensorflow-probability: '
29
+ 'compilation of some complex distributions '
30
+ '(Binomial, Negative-Binomial, Multinomial) will fail.', 'red')
31
+ traceback.print_exc()
32
+ tfp = None
33
+
34
34
 
35
35
  # ===========================================================================
36
36
  # EXACT RDDL TO JAX COMPILATION RULES
@@ -87,7 +87,7 @@ def _function_aggregation_exact_named(op, name):
87
87
  def _function_if_exact_named():
88
88
 
89
89
  def _jax_wrapped_if_exact(c, a, b, param):
90
- return jnp.where(c, a, b)
90
+ return jnp.where(c > 0.5, a, b)
91
91
 
92
92
  return _jax_wrapped_if_exact
93
93
 
@@ -114,16 +114,27 @@ def _function_bernoulli_exact_named():
114
114
  def _function_discrete_exact_named():
115
115
 
116
116
  def _jax_wrapped_discrete_exact(key, prob, param):
117
- logits = jnp.log(prob)
118
- sample = random.categorical(key=key, logits=logits, axis=-1)
119
- out_of_bounds = jnp.logical_not(jnp.logical_and(
120
- jnp.all(prob >= 0),
121
- jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
122
- return sample, out_of_bounds
117
+ return random.categorical(key=key, logits=jnp.log(prob), axis=-1)
123
118
 
124
119
  return _jax_wrapped_discrete_exact
125
120
 
126
121
 
122
+ def _function_poisson_exact_named():
123
+
124
+ def _jax_wrapped_poisson_exact(key, rate, param):
125
+ return random.poisson(key=key, lam=rate, dtype=jnp.int64)
126
+
127
+ return _jax_wrapped_poisson_exact
128
+
129
+
130
+ def _function_geometric_exact_named():
131
+
132
+ def _jax_wrapped_geometric_exact(key, prob, param):
133
+ return random.geometric(key=key, p=prob, dtype=jnp.int64)
134
+
135
+ return _jax_wrapped_geometric_exact
136
+
137
+
127
138
  class JaxRDDLCompiler:
128
139
  '''Compiles a RDDL AST representation into an equivalent JAX representation.
129
140
  All operations are identical to their numpy equivalents.
@@ -210,12 +221,12 @@ class JaxRDDLCompiler:
210
221
  }
211
222
 
212
223
  EXACT_RDDL_TO_JAX_IF = _function_if_exact_named()
213
-
214
224
  EXACT_RDDL_TO_JAX_SWITCH = _function_switch_exact_named()
215
225
 
216
226
  EXACT_RDDL_TO_JAX_BERNOULLI = _function_bernoulli_exact_named()
217
-
218
227
  EXACT_RDDL_TO_JAX_DISCRETE = _function_discrete_exact_named()
228
+ EXACT_RDDL_TO_JAX_POISSON = _function_poisson_exact_named()
229
+ EXACT_RDDL_TO_JAX_GEOMETRIC = _function_geometric_exact_named()
219
230
 
220
231
  def __init__(self, rddl: RDDLLiftedModel,
221
232
  allow_synchronous_state: bool=True,
@@ -289,6 +300,8 @@ class JaxRDDLCompiler:
289
300
  self.SWITCH_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
290
301
  self.BERNOULLI_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BERNOULLI
291
302
  self.DISCRETE_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
303
+ self.POISSON_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_POISSON
304
+ self.GEOMETRIC_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_GEOMETRIC
292
305
 
293
306
  # ===========================================================================
294
307
  # main compilation subroutines
@@ -996,13 +1009,14 @@ class JaxRDDLCompiler:
996
1009
  jax_op, jax_param = self._unwrap(negative_op, expr.id, info)
997
1010
  return self._jax_unary(jax_expr, jax_op, jax_param, at_least_int=True)
998
1011
 
999
- elif n == 2:
1000
- lhs, rhs = args
1001
- jax_lhs = self._jax(lhs, info)
1002
- jax_rhs = self._jax(rhs, info)
1012
+ elif n == 2 or (n >= 2 and op in {'*', '+'}):
1013
+ jax_exprs = [self._jax(arg, info) for arg in args]
1003
1014
  jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
1004
- return self._jax_binary(
1005
- jax_lhs, jax_rhs, jax_op, jax_param, at_least_int=True)
1015
+ result = jax_exprs[0]
1016
+ for jax_rhs in jax_exprs[1:]:
1017
+ result = self._jax_binary(
1018
+ result, jax_rhs, jax_op, jax_param, at_least_int=True)
1019
+ return result
1006
1020
 
1007
1021
  JaxRDDLCompiler._check_num_args(expr, 2)
1008
1022
 
@@ -1046,13 +1060,14 @@ class JaxRDDLCompiler:
1046
1060
  jax_op, jax_param = self._unwrap(logical_not_op, expr.id, info)
1047
1061
  return self._jax_unary(jax_expr, jax_op, jax_param, check_dtype=bool)
1048
1062
 
1049
- elif n == 2:
1050
- lhs, rhs = args
1051
- jax_lhs = self._jax(lhs, info)
1052
- jax_rhs = self._jax(rhs, info)
1063
+ elif n == 2 or (n >= 2 and op in {'^', '&', '|'}):
1064
+ jax_exprs = [self._jax(arg, info) for arg in args]
1053
1065
  jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
1054
- return self._jax_binary(
1055
- jax_lhs, jax_rhs, jax_op, jax_param, check_dtype=bool)
1066
+ result = jax_exprs[0]
1067
+ for jax_rhs in jax_exprs[1:]:
1068
+ result = self._jax_binary(
1069
+ result, jax_rhs, jax_op, jax_param, check_dtype=bool)
1070
+ return result
1056
1071
 
1057
1072
  JaxRDDLCompiler._check_num_args(expr, 2)
1058
1073
 
@@ -1165,16 +1180,17 @@ class JaxRDDLCompiler:
1165
1180
  return _jax_wrapped_if_then_else
1166
1181
 
1167
1182
  def _jax_switch(self, expr, info):
1183
+ pred, *_ = expr.args
1168
1184
 
1169
- # if expression is non-fluent, always use the exact operation
1170
- if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
1185
+ # if predicate is non-fluent, always use the exact operation
1186
+ # case conditions are currently only literals so they are non-fluent
1187
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(pred):
1171
1188
  switch_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
1172
1189
  else:
1173
1190
  switch_op = self.SWITCH_HELPER
1174
1191
  jax_switch, jax_param = self._unwrap(switch_op, expr.id, info)
1175
1192
 
1176
1193
  # recursively compile predicate
1177
- pred, *_ = expr.args
1178
1194
  jax_pred = self._jax(pred, info)
1179
1195
 
1180
1196
  # recursively compile cases
@@ -1426,15 +1442,24 @@ class JaxRDDLCompiler:
1426
1442
  def _jax_poisson(self, expr, info):
1427
1443
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_POISSON']
1428
1444
  JaxRDDLCompiler._check_num_args(expr, 1)
1429
-
1430
1445
  arg_rate, = expr.args
1446
+
1447
+ # if rate is non-fluent, always use the exact operation
1448
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_rate):
1449
+ poisson_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_POISSON
1450
+ else:
1451
+ poisson_op = self.POISSON_HELPER
1452
+ jax_poisson, jax_param = self._unwrap(poisson_op, expr.id, info)
1453
+
1454
+ # recursively compile arguments
1431
1455
  jax_rate = self._jax(arg_rate, info)
1432
1456
 
1433
1457
  # uses the implicit JAX subroutine
1434
1458
  def _jax_wrapped_distribution_poisson(x, params, key):
1435
1459
  rate, key, err = jax_rate(x, params, key)
1436
1460
  key, subkey = random.split(key)
1437
- sample = random.poisson(key=subkey, lam=rate, dtype=self.INT)
1461
+ param = params.get(jax_param, None)
1462
+ sample = jax_poisson(subkey, rate, param).astype(self.INT)
1438
1463
  out_of_bounds = jnp.logical_not(jnp.all(rate >= 0))
1439
1464
  err |= (out_of_bounds * ERR)
1440
1465
  return sample, key, err
@@ -1535,33 +1560,25 @@ class JaxRDDLCompiler:
1535
1560
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_GEOMETRIC']
1536
1561
  JaxRDDLCompiler._check_num_args(expr, 1)
1537
1562
  arg_prob, = expr.args
1563
+
1564
+ # if prob is non-fluent, always use the exact operation
1565
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_prob):
1566
+ geom_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_GEOMETRIC
1567
+ else:
1568
+ geom_op = self.GEOMETRIC_HELPER
1569
+ jax_geom, jax_param = self._unwrap(geom_op, expr.id, info)
1570
+
1571
+ # recursively compile arguments
1538
1572
  jax_prob = self._jax(arg_prob, info)
1539
1573
 
1540
- if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_prob):
1541
-
1542
- # prob is non-fluent: do not reparameterize
1543
- def _jax_wrapped_distribution_geometric(x, params, key):
1544
- prob, key, err = jax_prob(x, params, key)
1545
- key, subkey = random.split(key)
1546
- sample = random.geometric(key=subkey, p=prob, dtype=self.INT)
1547
- out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
1548
- err |= (out_of_bounds * ERR)
1549
- return sample, key, err
1550
-
1551
- else:
1552
- floor_op, jax_param = self._unwrap(
1553
- self.KNOWN_UNARY['floor'], expr.id, info)
1554
-
1555
- # reparameterization trick Geom(p) = floor(ln(U(0, 1)) / ln(p)) + 1
1556
- def _jax_wrapped_distribution_geometric(x, params, key):
1557
- prob, key, err = jax_prob(x, params, key)
1558
- key, subkey = random.split(key)
1559
- U = random.uniform(key=subkey, shape=jnp.shape(prob), dtype=self.REAL)
1560
- param = params.get(jax_param, None)
1561
- sample = floor_op(jnp.log(U) / jnp.log(1.0 - prob), param) + 1
1562
- out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
1563
- err |= (out_of_bounds * ERR)
1564
- return sample, key, err
1574
+ def _jax_wrapped_distribution_geometric(x, params, key):
1575
+ prob, key, err = jax_prob(x, params, key)
1576
+ key, subkey = random.split(key)
1577
+ param = params.get(jax_param, None)
1578
+ sample = jax_geom(subkey, prob, param).astype(self.INT)
1579
+ out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
1580
+ err |= (out_of_bounds * ERR)
1581
+ return sample, key, err
1565
1582
 
1566
1583
  return _jax_wrapped_distribution_geometric
1567
1584
 
@@ -1770,7 +1787,10 @@ class JaxRDDLCompiler:
1770
1787
  # dispatch to sampling subroutine
1771
1788
  key, subkey = random.split(key)
1772
1789
  param = params.get(jax_param, None)
1773
- sample, out_of_bounds = jax_discrete(subkey, prob, param)
1790
+ sample = jax_discrete(subkey, prob, param)
1791
+ out_of_bounds = jnp.logical_not(jnp.logical_and(
1792
+ jnp.all(prob >= 0),
1793
+ jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
1774
1794
  error |= (out_of_bounds * ERR)
1775
1795
  return sample, key, error
1776
1796
 
@@ -1803,7 +1823,10 @@ class JaxRDDLCompiler:
1803
1823
  # dispatch to sampling subroutine
1804
1824
  key, subkey = random.split(key)
1805
1825
  param = params.get(jax_param, None)
1806
- sample, out_of_bounds = jax_discrete(subkey, prob, param)
1826
+ sample = jax_discrete(subkey, prob, param)
1827
+ out_of_bounds = jnp.logical_not(jnp.logical_and(
1828
+ jnp.all(prob >= 0),
1829
+ jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
1807
1830
  error |= (out_of_bounds * ERR)
1808
1831
  return sample, key, error
1809
1832
 
@@ -1,11 +1,19 @@
1
+ from typing import Optional, Set
2
+
1
3
  import jax
2
4
  import jax.numpy as jnp
3
5
  import jax.random as random
4
- from typing import Optional, Set
5
6
 
6
7
  from pyRDDLGym.core.debug.exception import raise_warning
7
8
 
8
9
 
10
+ # ===========================================================================
11
+ # LOGICAL COMPLEMENT
12
+ # - abstract class
13
+ # - standard complement
14
+ #
15
+ # ===========================================================================
16
+
9
17
  class Complement:
10
18
  '''Base class for approximate logical complement operations.'''
11
19
 
@@ -20,6 +28,13 @@ class StandardComplement(Complement):
20
28
  return 1.0 - x
21
29
 
22
30
 
31
+ # ===========================================================================
32
+ # RELATIONAL OPERATIONS
33
+ # - abstract class
34
+ # - sigmoid comparison
35
+ #
36
+ # ===========================================================================
37
+
23
38
  class Comparison:
24
39
  '''Base class for approximate comparison operations.'''
25
40
 
@@ -44,7 +59,17 @@ class SigmoidComparison(Comparison):
44
59
 
45
60
  def equal(self, x, y, param):
46
61
  return 1.0 - jnp.square(jnp.tanh(param * (y - x)))
47
-
62
+
63
+
64
+ # ===========================================================================
65
+ # TNORMS
66
+ # - abstract tnorm
67
+ # - product tnorm
68
+ # - Godel tnorm
69
+ # - Lukasiewicz tnorm
70
+ # - Yager(p) tnorm
71
+ #
72
+ # ===========================================================================
48
73
 
49
74
  class TNorm:
50
75
  '''Base class for fuzzy differentiable t-norms.'''
@@ -86,8 +111,133 @@ class LukasiewiczTNorm(TNorm):
86
111
 
87
112
  def norms(self, x, axis):
88
113
  return jax.nn.relu(jnp.sum(x - 1.0, axis=axis) + 1.0)
114
+
115
+
116
+ class YagerTNorm(TNorm):
117
+ '''Yager t-norm given by the expression
118
+ (x, y) -> max(0, 1 - ((1 - x)^p + (1 - y)^p)^(1/p)).'''
119
+
120
+ def __init__(self, p=2.0):
121
+ self.p = p
122
+
123
+ def norm(self, x, y):
124
+ base_x = jax.nn.relu(1.0 - x)
125
+ base_y = jax.nn.relu(1.0 - y)
126
+ arg = jnp.power(base_x ** self.p + base_y ** self.p, 1.0 / self.p)
127
+ return jax.nn.relu(1.0 - arg)
128
+
129
+ def norms(self, x, axis):
130
+ base = jax.nn.relu(1.0 - x)
131
+ arg = jnp.power(jnp.sum(base ** self.p, axis=axis), 1.0 / self.p)
132
+ return jax.nn.relu(1.0 - arg)
133
+
134
+
135
+ # ===========================================================================
136
+ # RANDOM SAMPLING
137
+ # - abstract sampler
138
+ # - Gumbel-softmax sampler
139
+ # - determinization
140
+ #
141
+ # ===========================================================================
142
+
143
+ class RandomSampling:
144
+ '''An abstract class that describes how discrete and non-reparameterizable
145
+ random variables are sampled.'''
146
+
147
+ def discrete(self, logic):
148
+ raise NotImplementedError
149
+
150
+ def bernoulli(self, logic):
151
+ jax_discrete, jax_param = self.discrete(logic)
152
+
153
+ def _jax_wrapped_calc_bernoulli_approx(key, prob, param):
154
+ prob = jnp.stack([1.0 - prob, prob], axis=-1)
155
+ sample = jax_discrete(key, prob, param)
156
+ return sample
157
+
158
+ return _jax_wrapped_calc_bernoulli_approx, jax_param
159
+
160
+ def poisson(self, logic):
161
+
162
+ def _jax_wrapped_calc_poisson_exact(key, rate, param):
163
+ return random.poisson(key=key, lam=rate, dtype=logic.INT)
164
+
165
+ return _jax_wrapped_calc_poisson_exact, None
166
+
167
+ def geometric(self, logic):
168
+ if logic.verbose:
169
+ raise_warning('Using the replacement rule: '
170
+ 'Geometric(p) --> floor(log(U) / log(1 - p)) + 1')
171
+
172
+ jax_floor, jax_param = logic.floor()
173
+
174
+ def _jax_wrapped_calc_geometric_approx(key, prob, param):
175
+ U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
176
+ sample = jax_floor(jnp.log(U) / jnp.log(1.0 - prob), param) + 1
177
+ return sample
178
+
179
+ return _jax_wrapped_calc_geometric_approx, jax_param
180
+
181
+
182
+ class GumbelSoftmax(RandomSampling):
183
+ '''Random sampling of discrete variables using Gumbel-softmax trick.'''
89
184
 
185
+ def discrete(self, logic):
186
+ if logic.verbose:
187
+ raise_warning('Using the replacement rule: '
188
+ 'Discrete(p) --> Gumbel-softmax(p)')
189
+
190
+ jax_argmax, jax_param = logic.argmax()
191
+
192
+ def _jax_wrapped_calc_discrete_gumbel_softmax(key, prob, param):
193
+ Gumbel01 = random.gumbel(key=key, shape=prob.shape, dtype=logic.REAL)
194
+ sample = Gumbel01 + jnp.log(prob + logic.eps)
195
+ sample = jax_argmax(sample, axis=-1, param=param)
196
+ return sample
197
+
198
+ return _jax_wrapped_calc_discrete_gumbel_softmax, jax_param
199
+
200
+
201
+ class Determinization(RandomSampling):
202
+ '''Random sampling of variables using their deterministic mean estimate.'''
90
203
 
204
+ def discrete(self, logic):
205
+ if logic.verbose:
206
+ raise_warning('Using the replacement rule: '
207
+ 'Discrete(p) --> sum(i * p[i])')
208
+
209
+ def _jax_wrapped_calc_discrete_determinized(key, prob, param):
210
+ literals = FuzzyLogic.enumerate_literals(prob.shape, axis=-1)
211
+ sample = jnp.sum(literals * prob, axis=-1)
212
+ return sample
213
+
214
+ return _jax_wrapped_calc_discrete_determinized, None
215
+
216
+ def poisson(self, logic):
217
+ if logic.verbose:
218
+ raise_warning('Using the replacement rule: Poisson(rate) --> rate')
219
+
220
+ def _jax_wrapped_calc_poisson_determinized(key, rate, param):
221
+ return rate
222
+
223
+ return _jax_wrapped_calc_poisson_determinized, None
224
+
225
+ def geometric(self, logic):
226
+ if logic.verbose:
227
+ raise_warning('Using the replacement rule: Geometric(p) --> 1 / p')
228
+
229
+ def _jax_wrapped_calc_geometric_determinized(key, prob, param):
230
+ sample = 1.0 / prob
231
+ return sample
232
+
233
+ return _jax_wrapped_calc_geometric_determinized, None
234
+
235
+
236
+ # ===========================================================================
237
+ # FUZZY LOGIC
238
+ #
239
+ # ===========================================================================
240
+
91
241
  class FuzzyLogic:
92
242
  '''A class representing fuzzy logic in JAX.
93
243
 
@@ -98,9 +248,10 @@ class FuzzyLogic:
98
248
  def __init__(self, tnorm: TNorm=ProductTNorm(),
99
249
  complement: Complement=StandardComplement(),
100
250
  comparison: Comparison=SigmoidComparison(),
251
+ sampling: RandomSampling=GumbelSoftmax(),
101
252
  weight: float=10.0,
102
253
  debias: Optional[Set[str]]=None,
103
- eps: float=1e-10,
254
+ eps: float=1e-15,
104
255
  verbose: bool=False,
105
256
  use64bit: bool=False) -> None:
106
257
  '''Creates a new fuzzy logic in Jax.
@@ -108,8 +259,8 @@ class FuzzyLogic:
108
259
  :param tnorm: fuzzy operator for logical AND
109
260
  :param complement: fuzzy operator for logical NOT
110
261
  :param comparison: fuzzy operator for comparisons (>, >=, <, ==, ~=, ...)
262
+ :param sampling: random sampling of non-reparameterizable distributions
111
263
  :param weight: a sharpness parameter for sigmoid and softmax activations
112
- :param error: an error parameter (e.g. floor) (smaller means better accuracy)
113
264
  :param debias: which functions to de-bias approximate on forward pass
114
265
  :param eps: small positive float to mitigate underflow
115
266
  :param verbose: whether to dump replacements and other info to console
@@ -118,6 +269,7 @@ class FuzzyLogic:
118
269
  self.tnorm = tnorm
119
270
  self.complement = complement
120
271
  self.comparison = comparison
272
+ self.sampling = sampling
121
273
  self.weight = float(weight)
122
274
  if debias is None:
123
275
  debias = set()
@@ -142,10 +294,11 @@ class FuzzyLogic:
142
294
  f' tnorm ={type(self.tnorm).__name__}\n'
143
295
  f' complement ={type(self.complement).__name__}\n'
144
296
  f' comparison ={type(self.comparison).__name__}\n'
297
+ f' sampling ={type(self.sampling).__name__}\n'
145
298
  f' sigmoid_weight={self.weight}\n'
146
299
  f' cpfs_to_debias={self.debias}\n'
147
300
  f' underflow_tol ={self.eps}\n'
148
- f' use64bit ={self.use64bit}')
301
+ f' use_64_bit ={self.use64bit}')
149
302
 
150
303
  # ===========================================================================
151
304
  # logical operators
@@ -419,7 +572,7 @@ class FuzzyLogic:
419
572
  # ===========================================================================
420
573
 
421
574
  @staticmethod
422
- def _literals(shape, axis):
575
+ def enumerate_literals(shape, axis):
423
576
  literals = jnp.arange(shape[axis])
424
577
  literals = literals[(...,) + (jnp.newaxis,) * (len(shape) - 1)]
425
578
  literals = jnp.moveaxis(literals, source=0, destination=axis)
@@ -434,7 +587,7 @@ class FuzzyLogic:
434
587
  debias = 'argmax' in self.debias
435
588
 
436
589
  def _jax_wrapped_calc_argmax_approx(x, axis, param):
437
- literals = FuzzyLogic._literals(x.shape, axis=axis)
590
+ literals = FuzzyLogic.enumerate_literals(x.shape, axis=axis)
438
591
  soft_max = jax.nn.softmax(param * x, axis=axis)
439
592
  sample = jnp.sum(literals * soft_max, axis=axis)
440
593
  if debias:
@@ -468,7 +621,7 @@ class FuzzyLogic:
468
621
  def _jax_wrapped_calc_if_approx(c, a, b, param):
469
622
  sample = c * a + (1.0 - c) * b
470
623
  if debias:
471
- hard_sample = jnp.select([c, ~c], [a, b])
624
+ hard_sample = jnp.where(c > 0.5, a, b)
472
625
  sample += jax.lax.stop_gradient(hard_sample - sample)
473
626
  return sample
474
627
 
@@ -483,7 +636,7 @@ class FuzzyLogic:
483
636
  debias = 'switch' in self.debias
484
637
 
485
638
  def _jax_wrapped_calc_switch_approx(pred, cases, param):
486
- literals = FuzzyLogic._literals(cases.shape, axis=0)
639
+ literals = FuzzyLogic.enumerate_literals(cases.shape, axis=0)
487
640
  pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=cases.shape)
488
641
  proximity = -jnp.abs(pred - literals)
489
642
  soft_case = jax.nn.softmax(param * proximity, axis=0)
@@ -502,44 +655,24 @@ class FuzzyLogic:
502
655
  # random variables
503
656
  # ===========================================================================
504
657
 
505
- def _gumbel_softmax(self, key, prob):
506
- Gumbel01 = random.gumbel(key=key, shape=prob.shape, dtype=self.REAL)
507
- sample = Gumbel01 + jnp.log(prob + self.eps)
508
- return sample
509
-
658
+ def discrete(self):
659
+ return self.sampling.discrete(self)
660
+
510
661
  def bernoulli(self):
511
- if self.verbose:
512
- raise_warning('Using the replacement rule: '
513
- 'Bernoulli(p) --> Gumbel-softmax(p)')
514
-
515
- jax_gs = self._gumbel_softmax
516
- jax_argmax, jax_param = self.argmax()
517
-
518
- def _jax_wrapped_calc_bernoulli_approx(key, prob, param):
519
- prob = jnp.stack([1.0 - prob, prob], axis=-1)
520
- sample = jax_gs(key, prob)
521
- sample = jax_argmax(sample, axis=-1, param=param)
522
- return sample
523
-
524
- return _jax_wrapped_calc_bernoulli_approx, jax_param
662
+ return self.sampling.bernoulli(self)
525
663
 
526
- def discrete(self):
527
- if self.verbose:
528
- raise_warning('Using the replacement rule: '
529
- 'Discrete(p) --> Gumbel-softmax(p)')
530
-
531
- jax_gs = self._gumbel_softmax
532
- jax_argmax, jax_param = self.argmax()
533
-
534
- def _jax_wrapped_calc_discrete_approx(key, prob, param):
535
- sample = jax_gs(key, prob)
536
- sample = jax_argmax(sample, axis=-1, param=param)
537
- return sample
538
-
539
- return _jax_wrapped_calc_discrete_approx, jax_param
540
-
664
+ def poisson(self):
665
+ return self.sampling.poisson(self)
666
+
667
+ def geometric(self):
668
+ return self.sampling.geometric(self)
669
+
541
670
 
671
+ # ===========================================================================
542
672
  # UNIT TESTS
673
+ #
674
+ # ===========================================================================
675
+
543
676
  logic = FuzzyLogic()
544
677
  w = 100.0
545
678
 
@@ -598,13 +731,14 @@ def _test_random():
598
731
  key = random.PRNGKey(42)
599
732
  _bernoulli, _ = logic.bernoulli()
600
733
  _discrete, _ = logic.discrete()
734
+ _geometric, _ = logic.geometric()
601
735
 
602
736
  def bern(n):
603
737
  prob = jnp.asarray([0.3] * n)
604
738
  sample = _bernoulli(key, prob, w)
605
739
  return sample
606
740
 
607
- samples = bern(5000)
741
+ samples = bern(50000)
608
742
  print(jnp.mean(samples))
609
743
 
610
744
  def disc(n):
@@ -613,10 +747,18 @@ def _test_random():
613
747
  sample = _discrete(key, prob, w)
614
748
  return sample
615
749
 
616
- samples = disc(5000)
750
+ samples = disc(50000)
617
751
  samples = jnp.round(samples)
618
752
  print([jnp.mean(samples == i) for i in range(3)])
619
-
753
+
754
+ def geom(n):
755
+ prob = jnp.asarray([0.3] * n)
756
+ sample = _geometric(key, prob, w)
757
+ return sample
758
+
759
+ samples = geom(50000)
760
+ print(jnp.mean(samples))
761
+
620
762
 
621
763
  def _test_rounding():
622
764
  print('testing rounding')
@@ -2,54 +2,51 @@ from ast import literal_eval
2
2
  from collections import deque
3
3
  import configparser
4
4
  from enum import Enum
5
+ import os
6
+ import sys
7
+ import time
8
+ import traceback
9
+ from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Tuple, Union
10
+
5
11
  import haiku as hk
6
12
  import jax
13
+ import jax.nn.initializers as initializers
7
14
  import jax.numpy as jnp
8
15
  import jax.random as random
9
- import jax.nn.initializers as initializers
10
16
  import numpy as np
11
17
  import optax
12
- import os
13
- import sys
14
18
  import termcolor
15
- import time
16
- import traceback
17
19
  from tqdm import tqdm
18
- from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Tuple, Union
19
20
 
20
- Activation = Callable[[jnp.ndarray], jnp.ndarray]
21
- Bounds = Dict[str, Tuple[np.ndarray, np.ndarray]]
22
- Kwargs = Dict[str, Any]
23
- Pytree = Any
24
-
25
- from pyRDDLGym.core.debug.exception import raise_warning
26
-
27
- from pyRDDLGym_jax import __version__
28
-
29
- # try to import matplotlib, if failed then skip plotting
30
- try:
31
- import matplotlib
32
- import matplotlib.pyplot as plt
33
- matplotlib.use('TkAgg')
34
- except Exception:
35
- raise_warning('failed to import matplotlib: '
36
- 'plotting functionality will be disabled.', 'red')
37
- traceback.print_exc()
38
- plt = None
39
-
40
21
  from pyRDDLGym.core.compiler.model import RDDLPlanningModel, RDDLLiftedModel
41
22
  from pyRDDLGym.core.debug.logger import Logger
42
23
  from pyRDDLGym.core.debug.exception import (
24
+ raise_warning,
43
25
  RDDLNotImplementedError,
44
26
  RDDLUndefinedVariableError,
45
27
  RDDLTypeError
46
28
  )
47
29
  from pyRDDLGym.core.policy import BaseAgent
48
30
 
49
- from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler
31
+ from pyRDDLGym_jax import __version__
50
32
  from pyRDDLGym_jax.core import logic
33
+ from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler
51
34
  from pyRDDLGym_jax.core.logic import FuzzyLogic
52
35
 
36
+ # try to import matplotlib, if failed then skip plotting
37
+ try:
38
+ import matplotlib.pyplot as plt
39
+ except Exception:
40
+ raise_warning('failed to import matplotlib: '
41
+ 'plotting functionality will be disabled.', 'red')
42
+ traceback.print_exc()
43
+ plt = None
44
+
45
+ Activation = Callable[[jnp.ndarray], jnp.ndarray]
46
+ Bounds = Dict[str, Tuple[np.ndarray, np.ndarray]]
47
+ Kwargs = Dict[str, Any]
48
+ Pytree = Any
49
+
53
50
 
54
51
  # ***********************************************************************
55
52
  # CONFIG FILE MANAGEMENT
@@ -104,9 +101,12 @@ def _load_config(config, args):
104
101
  comp_kwargs = model_args.get('complement_kwargs', {})
105
102
  compare_name = model_args.get('comparison', 'SigmoidComparison')
106
103
  compare_kwargs = model_args.get('comparison_kwargs', {})
104
+ sampling_name = model_args.get('sampling', 'GumbelSoftmax')
105
+ sampling_kwargs = model_args.get('sampling_kwargs', {})
107
106
  logic_kwargs['tnorm'] = getattr(logic, tnorm_name)(**tnorm_kwargs)
108
107
  logic_kwargs['complement'] = getattr(logic, comp_name)(**comp_kwargs)
109
108
  logic_kwargs['comparison'] = getattr(logic, compare_name)(**compare_kwargs)
109
+ logic_kwargs['sampling'] = getattr(logic, sampling_name)(**sampling_kwargs)
110
110
 
111
111
  # read the policy settings
112
112
  plan_method = planner_args.pop('method')
@@ -184,18 +184,6 @@ def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
184
184
  #
185
185
  # ***********************************************************************
186
186
 
187
- def _function_discrete_approx_named(logic):
188
- jax_discrete, jax_param = logic.discrete()
189
-
190
- def _jax_wrapped_discrete_calc_approx(key, prob, params):
191
- sample = jax_discrete(key, prob, params)
192
- out_of_bounds = jnp.logical_not(jnp.logical_and(
193
- jnp.all(prob >= 0),
194
- jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
195
- return sample, out_of_bounds
196
-
197
- return _jax_wrapped_discrete_calc_approx, jax_param
198
-
199
187
 
200
188
  class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
201
189
  '''Compiles a RDDL AST representation to an equivalent JAX representation.
@@ -271,7 +259,9 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
271
259
  self.IF_HELPER = logic.control_if()
272
260
  self.SWITCH_HELPER = logic.control_switch()
273
261
  self.BERNOULLI_HELPER = logic.bernoulli()
274
- self.DISCRETE_HELPER = _function_discrete_approx_named(logic)
262
+ self.DISCRETE_HELPER = logic.discrete()
263
+ self.POISSON_HELPER = logic.poisson()
264
+ self.GEOMETRIC_HELPER = logic.geometric()
275
265
 
276
266
  def _jax_stop_grad(self, jax_expr):
277
267
 
@@ -469,7 +459,8 @@ class JaxStraightLinePlan(JaxPlan):
469
459
  f' wrap_non_bool ={self._wrap_non_bool}\n'
470
460
  f'constraint-sat strategy (complex):\n'
471
461
  f' wrap_softmax ={self._wrap_softmax}\n'
472
- f' use_new_projection ={self._use_new_projection}')
462
+ f' use_new_projection ={self._use_new_projection}\n'
463
+ f' max_projection_iters ={self._max_constraint_iter}')
473
464
 
474
465
  def compile(self, compiled: JaxRDDLCompilerWithGrad,
475
466
  _bounds: Bounds,
@@ -1348,8 +1339,18 @@ class JaxBackpropPlanner:
1348
1339
  map(str, jax._src.xla_bridge.devices())).replace('\n', '')
1349
1340
  except Exception as _:
1350
1341
  devices_short = 'N/A'
1342
+ LOGO = \
1343
+ """
1344
+ __ ______ __ __ ______ __ ______ __ __
1345
+ /\ \ /\ __ \ /\_\_\_\ /\ == \/\ \ /\ __ \ /\ "-.\ \
1346
+ _\_\ \ \ \ __ \ \/_/\_\/_ \ \ _-/\ \ \____ \ \ __ \ \ \ \-. \
1347
+ /\_____\ \ \_\ \_\ /\_\/\_\ \ \_\ \ \_____\ \ \_\ \_\ \ \_\\"\_\
1348
+ \/_____/ \/_/\/_/ \/_/\/_/ \/_/ \/_____/ \/_/\/_/ \/_/ \/_/
1349
+ """
1350
+
1351
1351
  print('\n'
1352
- f'JAX Planner version {__version__}\n'
1352
+ f'{LOGO}\n'
1353
+ f'Version {__version__}\n'
1353
1354
  f'Python {sys.version}\n'
1354
1355
  f'jax {jax.version.__version__}, jaxlib {jaxlib_version}, '
1355
1356
  f'optax {optax.__version__}, haiku {hk.__version__}, '
@@ -1711,6 +1712,14 @@ class JaxBackpropPlanner:
1711
1712
  hyperparam_value = float(policy_hyperparams)
1712
1713
  policy_hyperparams = {action: hyperparam_value
1713
1714
  for action in self.rddl.action_fluents}
1715
+
1716
+ # fill in missing entries
1717
+ elif isinstance(policy_hyperparams, dict):
1718
+ for action in self.rddl.action_fluents:
1719
+ if action not in policy_hyperparams:
1720
+ raise_warning(f'policy_hyperparams[{action}] is not set, '
1721
+ 'setting 1.0 which could be suboptimal.')
1722
+ policy_hyperparams[action] = 1.0
1714
1723
 
1715
1724
  # print summary of parameters:
1716
1725
  if print_summary:
@@ -1772,6 +1781,7 @@ class JaxBackpropPlanner:
1772
1781
  rolling_test_loss = RollingMean(test_rolling_window)
1773
1782
  log = {}
1774
1783
  status = JaxPlannerStatus.NORMAL
1784
+ is_all_zero_fn = lambda x: np.allclose(x, 0)
1775
1785
 
1776
1786
  # initialize plot area
1777
1787
  if plot_step is None or plot_step <= 0 or plt is None:
@@ -1786,6 +1796,7 @@ class JaxBackpropPlanner:
1786
1796
  iters = range(epochs)
1787
1797
  if print_progress:
1788
1798
  iters = tqdm(iters, total=100, position=tqdm_position)
1799
+ position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
1789
1800
 
1790
1801
  for it in iters:
1791
1802
  status = JaxPlannerStatus.NORMAL
@@ -1799,7 +1810,7 @@ class JaxBackpropPlanner:
1799
1810
 
1800
1811
  # no progress
1801
1812
  grad_norm_zero, _ = jax.tree_util.tree_flatten(
1802
- jax.tree_map(lambda x: np.allclose(x, 0), train_log['grad']))
1813
+ jax.tree_map(is_all_zero_fn, train_log['grad']))
1803
1814
  if np.all(grad_norm_zero):
1804
1815
  status = JaxPlannerStatus.NO_PROGRESS
1805
1816
 
@@ -1843,8 +1854,9 @@ class JaxBackpropPlanner:
1843
1854
  if print_progress:
1844
1855
  iters.n = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
1845
1856
  iters.set_description(
1846
- f'[{tqdm_position}] {it:6} it / {-train_loss:14.6f} train / '
1847
- f'{-test_loss:14.6f} test / {-best_loss:14.6f} best')
1857
+ f'{position_str} {it:6} it / {-train_loss:14.6f} train / '
1858
+ f'{-test_loss:14.6f} test / {-best_loss:14.6f} best / '
1859
+ f'{status.value} status')
1848
1860
 
1849
1861
  # reached computation budget
1850
1862
  if elapsed >= train_seconds:
@@ -1904,7 +1916,7 @@ class JaxBackpropPlanner:
1904
1916
  f' iterations ={it}\n'
1905
1917
  f' best_objective={-best_loss}\n'
1906
1918
  f' best_grad_norm={grad_norm}\n'
1907
- f'diagnosis: {diagnosis}\n')
1919
+ f' diagnosis: {diagnosis}\n')
1908
1920
 
1909
1921
  def _perform_diagnosis(self, last_iter_improve,
1910
1922
  train_return, test_return, best_return, grad_norm):
@@ -2116,7 +2128,7 @@ class JaxLineSearchPlanner(JaxBackpropPlanner):
2116
2128
  @jax.jit
2117
2129
  def entropic_utility(returns: jnp.ndarray, beta: float) -> float:
2118
2130
  return (-1.0 / beta) * jax.scipy.special.logsumexp(
2119
- -beta * returns, b=1.0 / returns.size)
2131
+ -beta * returns, b=1.0 / returns.size)
2120
2132
 
2121
2133
 
2122
2134
  @jax.jit
@@ -1,7 +1,8 @@
1
- import jax
2
1
  import time
3
2
  from typing import Dict, Optional
4
3
 
4
+ import jax
5
+
5
6
  from pyRDDLGym.core.compiler.model import RDDLLiftedModel
6
7
  from pyRDDLGym.core.debug.exception import (
7
8
  RDDLActionPreconditionNotSatisfiedError,
@@ -1,20 +1,18 @@
1
- from bayes_opt import BayesianOptimization
2
- from bayes_opt.util import UtilityFunction
3
1
  from copy import deepcopy
4
2
  import csv
5
3
  import datetime
6
- import jax
7
4
  from multiprocessing import get_context
8
- import numpy as np
9
5
  import os
10
6
  import time
11
7
  from typing import Any, Callable, Dict, Optional, Tuple
12
-
13
- Kwargs = Dict[str, Any]
14
-
15
8
  import warnings
16
9
  warnings.filterwarnings("ignore")
17
10
 
11
+ from bayes_opt import BayesianOptimization
12
+ from bayes_opt.util import UtilityFunction
13
+ import jax
14
+ import numpy as np
15
+
18
16
  from pyRDDLGym.core.debug.exception import raise_warning
19
17
  from pyRDDLGym.core.env import RDDLEnv
20
18
 
@@ -26,6 +24,8 @@ from pyRDDLGym_jax.core.planner import (
26
24
  JaxOnlineController
27
25
  )
28
26
 
27
+ Kwargs = Dict[str, Any]
28
+
29
29
 
30
30
  # ===============================================================================
31
31
  #
@@ -0,0 +1,276 @@
1
+ Metadata-Version: 2.1
2
+ Name: pyRDDLGym-jax
3
+ Version: 0.4
4
+ Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
5
+ Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
6
+ Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
7
+ Author-email: mike.gimelfarb@mail.utoronto.ca, ataitler@gmail.com, ssanner@mie.utoronto.ca
8
+ License: MIT License
9
+ Classifier: Development Status :: 3 - Alpha
10
+ Classifier: Intended Audience :: Science/Research
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Natural Language :: English
13
+ Classifier: Operating System :: OS Independent
14
+ Classifier: Programming Language :: Python :: 3
15
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
16
+ Requires-Python: >=3.8
17
+ Description-Content-Type: text/markdown
18
+ License-File: LICENSE
19
+ Requires-Dist: pyRDDLGym >=2.0
20
+ Requires-Dist: tqdm >=4.66
21
+ Requires-Dist: bayesian-optimization >=1.4.3
22
+ Requires-Dist: jax >=0.4.12
23
+ Requires-Dist: optax >=0.1.9
24
+ Requires-Dist: dm-haiku >=0.0.10
25
+ Requires-Dist: tensorflow-probability >=0.21.0
26
+
27
+ # pyRDDLGym-jax
28
+
29
+ Author: [Mike Gimelfarb](https://mike-gimelfarb.github.io)
30
+
31
+ This directory provides:
32
+ 1. automated translation and compilation of RDDL description files into [JAX](https://github.com/google/jax), converting any RDDL domain to a differentiable simulator!
33
+ 2. powerful, fast and scalable gradient-based planning algorithms, with extendible and flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and much more!
34
+
35
+ > [!NOTE]
36
+ > While the JAX planner can support some discrete state/action problems through model relaxations, it can perform poorly on some discrete problems (though there is an ongoing effort to remedy this!).
37
+ > If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
38
+
39
+ ## Contents
40
+
41
+ - [Installation](#installation)
42
+ - [Running from the Command Line](#running-from-the-command-line)
43
+ - [Running from within Python](#running-from-within-python)
44
+ - [Configuring the Planner](#configuring-the-planner)
45
+ - [Simulation](#simulation)
46
+ - [Manual Gradient Calculation](#manual-gradient-calculation)
47
+ - [Citing pyRDDLGym-jax](#citing-pyrddlgym-jax)
48
+
49
+ ## Installation
50
+
51
+ To use the compiler or planner without the automated hyper-parameter tuning, you will need the following packages installed:
52
+ - ``pyRDDLGym>=2.0``
53
+ - ``tqdm>=4.66``
54
+ - ``jax>=0.4.12``
55
+ - ``optax>=0.1.9``
56
+ - ``dm-haiku>=0.0.10``
57
+ - ``tensorflow-probability>=0.21.0``
58
+
59
+ Additionally, if you wish to run the examples, you need ``rddlrepository>=2``.
60
+ To run the automated tuning optimization, you will also need ``bayesian-optimization>=1.4.3``.
61
+
62
+ You can install this package, together with all of its requirements, via pip:
63
+
64
+ ```shell
65
+ pip install rddlrepository pyRDDLGym-jax
66
+ ```
67
+
68
+ ## Running from the Command Line
69
+
70
+ A basic run script is provided to run the Jax Planner on any domain in ``rddlrepository``, and can be launched in the command line from the install directory of pyRDDLGym-jax:
71
+
72
+ ```shell
73
+ python -m pyRDDLGym_jax.examples.run_plan <domain> <instance> <method> <episodes>
74
+ ```
75
+
76
+ where:
77
+ - ``domain`` is the domain identifier as specified in rddlrepository (i.e. Wildfire_MDP_ippc2014), or a path pointing to a valid ``domain.rddl`` file
78
+ - ``instance`` is the instance identifier (i.e. 1, 2, ... 10), or a path pointing to a valid ``instance.rddl`` file
79
+ - ``method`` is the planning method to use (i.e. drp, slp, replan)
80
+ - ``episodes`` is the (optional) number of episodes to evaluate the learned policy.
81
+
82
+ The ``method`` parameter supports three possible modes:
83
+ - ``slp`` is the basic straight line planner described [in this paper](https://proceedings.neurips.cc/paper_files/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)
84
+ - ``drp`` is the deep reactive policy network described [in this paper](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
85
+ - ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step.
86
+
87
+ A basic run script is also provided to run the automatic hyper-parameter tuning:
88
+
89
+ ```shell
90
+ python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials> <iters> <workers>
91
+ ```
92
+
93
+ where:
94
+ - ``domain`` is the domain identifier as specified in rddlrepository (i.e. Wildfire_MDP_ippc2014)
95
+ - ``instance`` is the instance identifier (i.e. 1, 2, ... 10)
96
+ - ``method`` is the planning method to use (i.e. drp, slp, replan)
97
+ - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
98
+ - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
99
+ - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``.
100
+
101
+ For example, the following will train the Jax Planner on the Quadcopter domain with 4 drones:
102
+
103
+ ```shell
104
+ python -m pyRDDLGym_jax.examples.run_plan Quadcopter 1 slp
105
+ ```
106
+
107
+ After several minutes of optimization, you should get a visualization as follows:
108
+
109
+ <p align="center">
110
+ <img src="Images/quadcopter.gif" width="400" height="400" margin=1/>
111
+ </p>
112
+
113
+ ## Running from within Python
114
+
115
+ To run the Jax planner from within a Python application, refer to the following example:
116
+
117
+ ```python
118
+ import pyRDDLGym
119
+ from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOfflineController
120
+
121
+ # set up the environment (note the vectorized option must be True)
122
+ env = pyRDDLGym.make("domain", "instance", vectorized=True)
123
+
124
+ # create the planning algorithm
125
+ planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
126
+ controller = JaxOfflineController(planner, **train_args)
127
+
128
+ # evaluate the planner
129
+ controller.evaluate(env, episodes=1, verbose=True, render=True)
130
+ env.close()
131
+ ```
132
+
133
+ Here, we have used the straight-line controller, although you can configure the combination of planner and policy representation if you wish.
134
+ All controllers are instances of pyRDDLGym's ``BaseAgent`` class, so they provide the ``evaluate()`` function to streamline interaction with the environment.
135
+ The ``**planner_args`` and ``**train_args`` are keyword arguments passed during initialization, but we strongly recommend creating and loading a config file as discussed in the next section.
136
+
137
+ ## Configuring the Planner
138
+
139
+ The simplest way to configure the planner is to write and pass a configuration file with the necessary [hyper-parameters](https://pyrddlgym.readthedocs.io/en/latest/jax.html#configuring-pyrddlgym-jax).
140
+ The basic structure of a configuration file is provided below for a straight-line planner:
141
+
142
+ ```ini
143
+ [Model]
144
+ logic='FuzzyLogic'
145
+ logic_kwargs={'weight': 20}
146
+ tnorm='ProductTNorm'
147
+ tnorm_kwargs={}
148
+
149
+ [Optimizer]
150
+ method='JaxStraightLinePlan'
151
+ method_kwargs={}
152
+ optimizer='rmsprop'
153
+ optimizer_kwargs={'learning_rate': 0.001}
154
+ batch_size_train=1
155
+ batch_size_test=1
156
+
157
+ [Training]
158
+ key=42
159
+ epochs=5000
160
+ train_seconds=30
161
+ ```
162
+
163
+ The configuration file contains three sections:
164
+ - ``[Model]`` specifies the fuzzy logic operations used to relax discrete operations to differentiable approximations; the ``weight`` dictates the quality of the approximation,
165
+ and ``tnorm`` specifies the type of [fuzzy logic](https://en.wikipedia.org/wiki/T-norm_fuzzy_logics) for replacing logical operations in RDDL (e.g. ``ProductTNorm``, ``GodelTNorm``, ``LukasiewiczTNorm``)
166
+ - ``[Optimizer]`` specifies the optimizer and plan settings; the ``method`` specifies the plan/policy representation (e.g. ``JaxStraightLinePlan``, ``JaxDeepReactivePolicy``), the gradient descent settings, learning rate, batch size, etc.
167
+ - ``[Training]`` specifies computation limits, such as total training time and number of iterations, and options for printing or visualizing information from the planner.
168
+
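The ``[Model]`` section in version 0.4 can also select how non-reparameterizable random variables are relaxed (see the ``sampling`` option read in ``_load_config`` above). A minimal sketch of such a section, assuming the ``sampling``/``sampling_kwargs`` keys and the ``GumbelSoftmax`` and ``Determinization`` class names from ``pyRDDLGym_jax.core.logic``:

```ini
[Model]
logic='FuzzyLogic'
logic_kwargs={'weight': 20}
tnorm='ProductTNorm'
tnorm_kwargs={}
; sketch only: sampling relaxation added in 0.4, e.g. 'GumbelSoftmax' or 'Determinization'
sampling='GumbelSoftmax'
sampling_kwargs={}
```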
169
+ For a policy network approach, simply change the ``[Optimizer]`` settings like so:
170
+
171
+ ```ini
172
+ ...
173
+ [Optimizer]
174
+ method='JaxDeepReactivePolicy'
175
+ method_kwargs={'topology': [128, 64], 'activation': 'tanh'}
176
+ ...
177
+ ```
178
+
179
+ The configuration file must then be passed to the planner during initialization.
180
+ For example, the [previous script here](#running-from-within-python) can be modified to set parameters from a config file:
181
+
182
+ ```python
183
+ from pyRDDLGym_jax.core.planner import load_config
184
+
185
+ # load the config file with planner settings
186
+ planner_args, _, train_args = load_config("/path/to/config.cfg")
187
+
188
+ # create the planning algorithm
189
+ planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
190
+ controller = JaxOfflineController(planner, **train_args)
191
+ ...
192
+ ```
193
+
194
+ ## Simulation
195
+
196
+ The JAX compiler can be used as a backend for simulating and evaluating RDDL environments:
197
+
198
+ ```python
199
+ import pyRDDLGym
200
+ from pyRDDLGym.core.policy import RandomAgent
201
+ from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator
202
+
203
+ # create the environment
204
+ env = pyRDDLGym.make("domain", "instance", backend=JaxRDDLSimulator)
205
+
206
+ # evaluate the random policy
207
+ agent = RandomAgent(action_space=env.action_space,
208
+ num_actions=env.max_allowed_actions)
209
+ agent.evaluate(env, verbose=True, render=True)
210
+ ```
211
+
212
+ For some domains, the JAX backend could perform better than the numpy-based one, due to various compiler optimizations.
213
+ In any event, the simulation results using the JAX backend should (almost) always match the numpy backend.
214
+
215
+ ## Manual Gradient Calculation
216
+
217
+ For custom applications, it is often desirable to compute gradients of the model that can then be used for downstream optimization.
218
+ Fortunately, we provide a very convenient function for compiling the transition/step function ``P(s, a, s')`` of the environment into JAX.
219
+
220
+ ```python
221
+ import pyRDDLGym
222
+ from pyRDDLGym_jax.core.planner import JaxRDDLCompilerWithGrad
223
+
224
+ # set up the environment
225
+ env = pyRDDLGym.make("domain", "instance", vectorized=True)
226
+
227
+ # create the step function
228
+ compiled = JaxRDDLCompilerWithGrad(rddl=env.model)
229
+ compiled.compile()
230
+ step_fn = compiled.compile_transition()
231
+ ```
232
+
233
+ This will return a JAX compiled (pure) function requiring the following inputs:
234
+ - ``key`` is the ``jax.random.PRNGKey`` key for reproducible randomness
235
+ - ``actions`` is the dictionary of action fluent tensors
236
+ - ``subs`` is the dictionary of state-fluent and non-fluent tensors
237
+ - ``model_params`` are the parameters of the differentiable relaxations, such as ``weight``
238
+
239
+ The function returns a dictionary containing a variety of variables, such as the updated pvariables including next-state fluents (``pvar``), the reward obtained (``reward``), and error codes (``error``).
240
+ It is thus possible to apply any JAX transformation to the output of the function, such as computing gradient using ``jax.grad()`` or batched simulation using ``jax.vmap()``.
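As a minimal sketch of the ``jax.grad()`` route (assuming ``step_fn`` from the snippet above accepts its arguments in the order listed, and that ``key``, ``actions``, ``subs`` and ``model_params`` have been prepared as described; the scalar reduction is illustrative only):

```python
import jax

# sketch: differentiate the immediate reward with respect to the action tensors
def reward_fn(actions, key, subs, model_params):
    out = step_fn(key, actions, subs, model_params)   # argument order assumed
    return out['reward'].sum()                        # reduce to a scalar for jax.grad

grad_wrt_actions = jax.grad(reward_fn)(actions, key, subs, model_params)
```

Because ``actions`` is a dictionary (a pytree), the result is a dictionary of gradients with the same structure as the action tensors.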
241
+
242
+ Compilation of entire rollouts is also possible by calling the ``compile_rollouts`` function.
243
+ An [example is provided to illustrate how you can define your own policy class and compute the return gradient manually](https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/pyRDDLGym_jax/examples/run_gradient.py).
244
+
245
+ ## Citing pyRDDLGym-jax
246
+
247
+ The [following citation](https://ojs.aaai.org/index.php/ICAPS/article/view/31480) describes the main ideas of the framework. Please cite it if you found it useful:
248
+
249
+ ```
250
+ @inproceedings{gimelfarb2024jaxplan,
251
+ title={JaxPlan and GurobiPlan: Optimization Baselines for Replanning in Discrete and Mixed Discrete and Continuous Probabilistic Domains},
252
+ author={Michael Gimelfarb and Ayal Taitler and Scott Sanner},
253
+ booktitle={34th International Conference on Automated Planning and Scheduling},
254
+ year={2024},
255
+ url={https://openreview.net/forum?id=7IKtmUpLEH}
256
+ }
257
+ ```
258
+
259
+ The utility optimization is discussed in [this paper](https://ojs.aaai.org/index.php/AAAI/article/view/21226):
260
+
261
+ ```
262
+ @inproceedings{patton2022distributional,
263
+ title={A distributional framework for risk-sensitive end-to-end planning in continuous mdps},
264
+ author={Patton, Noah and Jeong, Jihwan and Gimelfarb, Mike and Sanner, Scott},
265
+ booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
266
+ volume={36},
267
+ number={9},
268
+ pages={9894--9901},
269
+ year={2022}
270
+ }
271
+ ```
272
+
273
+ Some of the implementation details derive from the following literature, which you may wish to also cite in your research papers:
274
+ - [Deep reactive policies for planning in stochastic nonlinear domains, AAAI 2019](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
275
+ - [Scalable planning with tensorflow for hybrid nonlinear domains, NeurIPS 2017](https://proceedings.neurips.cc/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)
276
+
@@ -1,10 +1,10 @@
1
- pyRDDLGym_jax/__init__.py,sha256=Cl7DWkrPP64Ofc2ILXnudFOdnCuKs2p0Pm7ykZOOPh4,19
1
+ pyRDDLGym_jax/__init__.py,sha256=rexmxcBiCOcwctw4wGvk7UxS9MfZn_1CYXp53SoLKlU,19
2
2
  pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- pyRDDLGym_jax/core/compiler.py,sha256=m7p0CHOU4Wma0cKMu_WQwfoieIQ2pXD68hZ8BFJ970A,89103
4
- pyRDDLGym_jax/core/logic.py,sha256=zujSHiR5KhTO81E5Zn8Gy_xSzVzfDskFCGvZygFRdMI,21930
5
- pyRDDLGym_jax/core/planner.py,sha256=1BtU1G3rihRZaMfNu0VtbSl1LXEXu6pT75EkF6-WVnM,101827
6
- pyRDDLGym_jax/core/simulator.py,sha256=fp6bep3XwwBWED0w7_4qhiwDjkSka6B2prwdNcPRCMc,8329
7
- pyRDDLGym_jax/core/tuning.py,sha256=Dv0YyOgGnej-zdVymWdkVg0MZjm2lNRfr7gySzFOeow,29589
3
+ pyRDDLGym_jax/core/compiler.py,sha256=SnDN3-J84Wv_YVHoDmfM_U4Ob8uaFLGX4vEaeWC-ERY,90037
4
+ pyRDDLGym_jax/core/logic.py,sha256=o1YAjMnXfi8gwb42kAigBmaf9uIYUWal9__FEkWohrk,26733
5
+ pyRDDLGym_jax/core/planner.py,sha256=Hrwfn88bUu1LNZcnFC5psHPzcIUbPeF4Rn1pFO6_qH0,102655
6
+ pyRDDLGym_jax/core/simulator.py,sha256=hWv6pr-4V-SSCzBYgdIPmKdUDMalft-Zh6dzOo5O9-0,8331
7
+ pyRDDLGym_jax/core/tuning.py,sha256=D_kD8wjqMroCdtjE9eksR2UqrqXJqazsAKrMEHwPxYM,29589
8
8
  pyRDDLGym_jax/examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
9
  pyRDDLGym_jax/examples/run_gradient.py,sha256=KhXvijRDZ4V7N8NOI2WV8ePGpPna5_vnET61YwS7Tco,2919
10
10
  pyRDDLGym_jax/examples/run_gym.py,sha256=rXvNWkxe4jHllvbvU_EOMji_2-2k5d4tbBKhpMm_Gaw,1526
@@ -37,8 +37,8 @@ pyRDDLGym_jax/examples/configs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5
37
37
  pyRDDLGym_jax/examples/configs/default_drp.cfg,sha256=S2-5hPZtgAwUAFpiCAgSi-cnGhYHSDzMGMmatwhbM78,344
38
38
  pyRDDLGym_jax/examples/configs/default_replan.cfg,sha256=VWWPhOYBRq4cWwtrChw5pPqRmlX_nHbMvwciHd9hoLc,357
39
39
  pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=TG3mtHUnCA7J2Gm9SczENpqAymTnzCE9dj1Z_R-FnVk,340
40
- pyRDDLGym_jax-0.3.dist-info/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
41
- pyRDDLGym_jax-0.3.dist-info/METADATA,sha256=e_1MlMdQoqQHW-KA2OSIZzIAQyfe-jDtMOxkIyhmLmI,1085
42
- pyRDDLGym_jax-0.3.dist-info/WHEEL,sha256=y4mX-SOX4fYIkonsAGA5N0Oy-8_gI4FXw5HNI1xqvWg,91
43
- pyRDDLGym_jax-0.3.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
44
- pyRDDLGym_jax-0.3.dist-info/RECORD,,
40
+ pyRDDLGym_jax-0.4.dist-info/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
41
+ pyRDDLGym_jax-0.4.dist-info/METADATA,sha256=-Kf8PLxf_7MiiYXzlZAf31kV1pT-Rurc7QY7dT3Fwk0,12857
42
+ pyRDDLGym_jax-0.4.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
43
+ pyRDDLGym_jax-0.4.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
44
+ pyRDDLGym_jax-0.4.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (70.2.0)
2
+ Generator: setuptools (75.3.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,26 +0,0 @@
1
- Metadata-Version: 2.1
2
- Name: pyRDDLGym-jax
3
- Version: 0.3
4
- Summary: pyRDDLGym-jax: JAX compilation of RDDL description files, and a differentiable planner in JAX.
5
- Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
6
- Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
7
- Author-email: mike.gimelfarb@mail.utoronto.ca, ataitler@gmail.com, ssanner@mie.utoronto.ca
8
- License: MIT License
9
- Classifier: Development Status :: 3 - Alpha
10
- Classifier: Intended Audience :: Science/Research
11
- Classifier: License :: OSI Approved :: MIT License
12
- Classifier: Natural Language :: English
13
- Classifier: Operating System :: OS Independent
14
- Classifier: Programming Language :: Python :: 3
15
- Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
16
- Requires-Python: >=3.8
17
- License-File: LICENSE
18
- Requires-Dist: pyRDDLGym >=2.0
19
- Requires-Dist: tqdm >=4.66
20
- Requires-Dist: bayesian-optimization >=1.4.3
21
- Requires-Dist: jax >=0.4.12
22
- Requires-Dist: optax >=0.1.9
23
- Requires-Dist: dm-haiku >=0.0.10
24
- Requires-Dist: tensorflow >=2.13.0
25
- Requires-Dist: tensorflow-probability >=0.21.0
26
-