pyRDDLGym-jax 0.1__py3-none-any.whl → 0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. pyRDDLGym_jax/__init__.py +1 -0
  2. pyRDDLGym_jax/core/compiler.py +444 -221
  3. pyRDDLGym_jax/core/logic.py +129 -62
  4. pyRDDLGym_jax/core/planner.py +965 -394
  5. pyRDDLGym_jax/core/simulator.py +5 -7
  6. pyRDDLGym_jax/core/tuning.py +29 -15
  7. pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_drp.cfg → Cartpole_Continuous_gym_drp.cfg} +2 -3
  8. pyRDDLGym_jax/examples/configs/{HVAC_drp.cfg → HVAC_ippc2023_drp.cfg} +4 -4
  9. pyRDDLGym_jax/examples/configs/{MarsRover_drp.cfg → MarsRover_ippc2023_drp.cfg} +1 -0
  10. pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +19 -0
  11. pyRDDLGym_jax/examples/configs/{Pendulum_slp.cfg → Pendulum_gym_slp.cfg} +1 -1
  12. pyRDDLGym_jax/examples/configs/{Pong_slp.cfg → Quadcopter_drp.cfg} +5 -5
  13. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +18 -0
  14. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +1 -1
  15. pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +1 -1
  16. pyRDDLGym_jax/examples/configs/default_drp.cfg +19 -0
  17. pyRDDLGym_jax/examples/configs/default_replan.cfg +20 -0
  18. pyRDDLGym_jax/examples/configs/default_slp.cfg +19 -0
  19. pyRDDLGym_jax/examples/run_gradient.py +1 -1
  20. pyRDDLGym_jax/examples/run_gym.py +3 -7
  21. pyRDDLGym_jax/examples/run_plan.py +10 -5
  22. pyRDDLGym_jax/examples/run_scipy.py +61 -0
  23. pyRDDLGym_jax/examples/run_tune.py +8 -3
  24. {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/METADATA +1 -1
  25. pyRDDLGym_jax-0.3.dist-info/RECORD +44 -0
  26. {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/WHEEL +1 -1
  27. pyRDDLGym_jax/examples/configs/SupplyChain_slp.cfg +0 -18
  28. pyRDDLGym_jax/examples/configs/Traffic_slp.cfg +0 -20
  29. pyRDDLGym_jax-0.1.dist-info/RECORD +0 -40
  30. /pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_replan.cfg → Cartpole_Continuous_gym_replan.cfg} +0 -0
  31. /pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_slp.cfg → Cartpole_Continuous_gym_slp.cfg} +0 -0
  32. /pyRDDLGym_jax/examples/configs/{HVAC_slp.cfg → HVAC_ippc2023_slp.cfg} +0 -0
  33. /pyRDDLGym_jax/examples/configs/{MarsRover_slp.cfg → MarsRover_ippc2023_slp.cfg} +0 -0
  34. /pyRDDLGym_jax/examples/configs/{MountainCar_slp.cfg → MountainCar_Continuous_gym_slp.cfg} +0 -0
  35. /pyRDDLGym_jax/examples/configs/{PowerGen_drp.cfg → PowerGen_Continuous_drp.cfg} +0 -0
  36. /pyRDDLGym_jax/examples/configs/{PowerGen_replan.cfg → PowerGen_Continuous_replan.cfg} +0 -0
  37. /pyRDDLGym_jax/examples/configs/{PowerGen_slp.cfg → PowerGen_Continuous_slp.cfg} +0 -0
  38. {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/LICENSE +0 -0
  39. {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/top_level.txt +0 -0
@@ -4,7 +4,7 @@ import jax.numpy as jnp
4
4
  import jax.random as random
5
5
  import jax.scipy as scipy
6
6
  import traceback
7
- from typing import Callable, Dict, List
7
+ from typing import Any, Callable, Dict, List, Optional
8
8
 
9
9
  from pyRDDLGym.core.debug.exception import raise_warning
10
10
 
@@ -13,8 +13,8 @@ try:
13
13
  from tensorflow_probability.substrates import jax as tfp
14
14
  except Exception:
15
15
  raise_warning('Failed to import tensorflow-probability: '
16
- 'compilation of some complex distributions will not work.',
17
- 'red')
16
+ 'compilation of some complex distributions '
17
+ '(Binomial, Negative-Binomial, Multinomial) will fail.', 'red')
18
18
  traceback.print_exc()
19
19
  tfp = None
20
20
 
@@ -32,6 +32,98 @@ from pyRDDLGym.core.debug.logger import Logger
32
32
  from pyRDDLGym.core.simulator import RDDLSimulatorPrecompiled
33
33
 
34
34
 
35
+ # ===========================================================================
36
+ # EXACT RDDL TO JAX COMPILATION RULES
37
+ # ===========================================================================
38
+
39
+ def _function_unary_exact_named(op, name):
40
+
41
+ def _jax_wrapped_unary_fn_exact(x, param):
42
+ return op(x)
43
+
44
+ return _jax_wrapped_unary_fn_exact
45
+
46
+
47
+ def _function_unary_exact_named_gamma():
48
+
49
+ def _jax_wrapped_unary_gamma_exact(x, param):
50
+ return jnp.exp(scipy.special.gammaln(x))
51
+
52
+ return _jax_wrapped_unary_gamma_exact
53
+
54
+
55
+ def _function_binary_exact_named(op, name):
56
+
57
+ def _jax_wrapped_binary_fn_exact(x, y, param):
58
+ return op(x, y)
59
+
60
+ return _jax_wrapped_binary_fn_exact
61
+
62
+
63
+ def _function_binary_exact_named_implies():
64
+
65
+ def _jax_wrapped_binary_implies_exact(x, y, param):
66
+ return jnp.logical_or(jnp.logical_not(x), y)
67
+
68
+ return _jax_wrapped_binary_implies_exact
69
+
70
+
71
+ def _function_binary_exact_named_log():
72
+
73
+ def _jax_wrapped_binary_log_exact(x, y, param):
74
+ return jnp.log(x) / jnp.log(y)
75
+
76
+ return _jax_wrapped_binary_log_exact
77
+
78
+
79
+ def _function_aggregation_exact_named(op, name):
80
+
81
+ def _jax_wrapped_aggregation_fn_exact(x, axis, param):
82
+ return op(x, axis=axis)
83
+
84
+ return _jax_wrapped_aggregation_fn_exact
85
+
86
+
87
+ def _function_if_exact_named():
88
+
89
+ def _jax_wrapped_if_exact(c, a, b, param):
90
+ return jnp.where(c, a, b)
91
+
92
+ return _jax_wrapped_if_exact
93
+
94
+
95
+ def _function_switch_exact_named():
96
+
97
+ def _jax_wrapped_switch_exact(pred, cases, param):
98
+ pred = pred[jnp.newaxis, ...]
99
+ sample = jnp.take_along_axis(cases, pred, axis=0)
100
+ assert sample.shape[0] == 1
101
+ return sample[0, ...]
102
+
103
+ return _jax_wrapped_switch_exact
104
+
105
+
106
+ def _function_bernoulli_exact_named():
107
+
108
+ def _jax_wrapped_bernoulli_exact(key, prob, param):
109
+ return random.bernoulli(key, prob)
110
+
111
+ return _jax_wrapped_bernoulli_exact
112
+
113
+
114
+ def _function_discrete_exact_named():
115
+
116
+ def _jax_wrapped_discrete_exact(key, prob, param):
117
+ logits = jnp.log(prob)
118
+ sample = random.categorical(key=key, logits=logits, axis=-1)
119
+ out_of_bounds = jnp.logical_not(jnp.logical_and(
120
+ jnp.all(prob >= 0),
121
+ jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
122
+ return sample, out_of_bounds
123
+
124
+ return _jax_wrapped_discrete_exact
125
+
126
+
35
127
  class JaxRDDLCompiler:
36
128
  '''Compiles a RDDL AST representation into an equivalent JAX representation.
37
129
  All operations are identical to their numpy equivalents.
@@ -39,10 +131,97 @@ class JaxRDDLCompiler:
39
131
 
40
132
  MODEL_PARAM_TAG_SEPARATOR = '___'
41
133
 
134
+ # ===========================================================================
135
+ # EXACT RDDL TO JAX COMPILATION RULES BY DEFAULT
136
+ # ===========================================================================
137
+
138
+ EXACT_RDDL_TO_JAX_NEGATIVE = _function_unary_exact_named(jnp.negative, 'negative')
139
+
140
+ EXACT_RDDL_TO_JAX_ARITHMETIC = {
141
+ '+': _function_binary_exact_named(jnp.add, 'add'),
142
+ '-': _function_binary_exact_named(jnp.subtract, 'subtract'),
143
+ '*': _function_binary_exact_named(jnp.multiply, 'multiply'),
144
+ '/': _function_binary_exact_named(jnp.divide, 'divide')
145
+ }
146
+
147
+ EXACT_RDDL_TO_JAX_RELATIONAL = {
148
+ '>=': _function_binary_exact_named(jnp.greater_equal, 'greater_equal'),
149
+ '<=': _function_binary_exact_named(jnp.less_equal, 'less_equal'),
150
+ '<': _function_binary_exact_named(jnp.less, 'less'),
151
+ '>': _function_binary_exact_named(jnp.greater, 'greater'),
152
+ '==': _function_binary_exact_named(jnp.equal, 'equal'),
153
+ '~=': _function_binary_exact_named(jnp.not_equal, 'not_equal')
154
+ }
155
+
156
+ EXACT_RDDL_TO_JAX_LOGICAL = {
157
+ '^': _function_binary_exact_named(jnp.logical_and, 'and'),
158
+ '&': _function_binary_exact_named(jnp.logical_and, 'and'),
159
+ '|': _function_binary_exact_named(jnp.logical_or, 'or'),
160
+ '~': _function_binary_exact_named(jnp.logical_xor, 'xor'),
161
+ '=>': _function_binary_exact_named_implies(),
162
+ '<=>': _function_binary_exact_named(jnp.equal, 'iff')
163
+ }
164
+
165
+ EXACT_RDDL_TO_JAX_LOGICAL_NOT = _function_unary_exact_named(jnp.logical_not, 'not')
166
+
167
+ EXACT_RDDL_TO_JAX_AGGREGATION = {
168
+ 'sum': _function_aggregation_exact_named(jnp.sum, 'sum'),
169
+ 'avg': _function_aggregation_exact_named(jnp.mean, 'avg'),
170
+ 'prod': _function_aggregation_exact_named(jnp.prod, 'prod'),
171
+ 'minimum': _function_aggregation_exact_named(jnp.min, 'minimum'),
172
+ 'maximum': _function_aggregation_exact_named(jnp.max, 'maximum'),
173
+ 'forall': _function_aggregation_exact_named(jnp.all, 'forall'),
174
+ 'exists': _function_aggregation_exact_named(jnp.any, 'exists'),
175
+ 'argmin': _function_aggregation_exact_named(jnp.argmin, 'argmin'),
176
+ 'argmax': _function_aggregation_exact_named(jnp.argmax, 'argmax')
177
+ }
178
+
179
+ EXACT_RDDL_TO_JAX_UNARY = {
180
+ 'abs': _function_unary_exact_named(jnp.abs, 'abs'),
181
+ 'sgn': _function_unary_exact_named(jnp.sign, 'sgn'),
182
+ 'round': _function_unary_exact_named(jnp.round, 'round'),
183
+ 'floor': _function_unary_exact_named(jnp.floor, 'floor'),
184
+ 'ceil': _function_unary_exact_named(jnp.ceil, 'ceil'),
185
+ 'cos': _function_unary_exact_named(jnp.cos, 'cos'),
186
+ 'sin': _function_unary_exact_named(jnp.sin, 'sin'),
187
+ 'tan': _function_unary_exact_named(jnp.tan, 'tan'),
188
+ 'acos': _function_unary_exact_named(jnp.arccos, 'acos'),
189
+ 'asin': _function_unary_exact_named(jnp.arcsin, 'asin'),
190
+ 'atan': _function_unary_exact_named(jnp.arctan, 'atan'),
191
+ 'cosh': _function_unary_exact_named(jnp.cosh, 'cosh'),
192
+ 'sinh': _function_unary_exact_named(jnp.sinh, 'sinh'),
193
+ 'tanh': _function_unary_exact_named(jnp.tanh, 'tanh'),
194
+ 'exp': _function_unary_exact_named(jnp.exp, 'exp'),
195
+ 'ln': _function_unary_exact_named(jnp.log, 'ln'),
196
+ 'sqrt': _function_unary_exact_named(jnp.sqrt, 'sqrt'),
197
+ 'lngamma': _function_unary_exact_named(scipy.special.gammaln, 'lngamma'),
198
+ 'gamma': _function_unary_exact_named_gamma()
199
+ }
200
+
201
+ EXACT_RDDL_TO_JAX_BINARY = {
202
+ 'div': _function_binary_exact_named(jnp.floor_divide, 'div'),
203
+ 'mod': _function_binary_exact_named(jnp.mod, 'mod'),
204
+ 'fmod': _function_binary_exact_named(jnp.mod, 'fmod'),
205
+ 'min': _function_binary_exact_named(jnp.minimum, 'min'),
206
+ 'max': _function_binary_exact_named(jnp.maximum, 'max'),
207
+ 'pow': _function_binary_exact_named(jnp.power, 'pow'),
208
+ 'log': _function_binary_exact_named_log(),
209
+ 'hypot': _function_binary_exact_named(jnp.hypot, 'hypot'),
210
+ }
211
+
212
+ EXACT_RDDL_TO_JAX_IF = _function_if_exact_named()
213
+
214
+ EXACT_RDDL_TO_JAX_SWITCH = _function_switch_exact_named()
215
+
216
+ EXACT_RDDL_TO_JAX_BERNOULLI = _function_bernoulli_exact_named()
217
+
218
+ EXACT_RDDL_TO_JAX_DISCRETE = _function_discrete_exact_named()
219
+
42
220
  def __init__(self, rddl: RDDLLiftedModel,
43
221
  allow_synchronous_state: bool=True,
44
- logger: Logger=None,
45
- use64bit: bool=False) -> None:
222
+ logger: Optional[Logger]=None,
223
+ use64bit: bool=False,
224
+ compile_non_fluent_exact: bool=True) -> None:
46
225
  '''Creates a new RDDL to Jax compiler.
47
226
 
48
227
  :param rddl: the RDDL model to compile into Jax
@@ -50,11 +229,14 @@ class JaxRDDLCompiler:
50
229
  on each other
51
230
  :param logger: to log information about compilation to file
52
231
  :param use64bit: whether to use 64 bit arithmetic
232
+ :param compile_non_fluent_exact: whether non-fluent expressions
233
+ are always compiled using exact JAX expressions.
53
234
  '''
54
235
  self.rddl = rddl
55
236
  self.logger = logger
56
237
  # jax.config.update('jax_log_compiles', True) # for testing ONLY
57
238
 
239
+ self.use64bit = use64bit
58
240
  if use64bit:
59
241
  self.INT = jnp.int64
60
242
  self.REAL = jnp.float64
@@ -62,6 +244,7 @@ class JaxRDDLCompiler:
62
244
  else:
63
245
  self.INT = jnp.int32
64
246
  self.REAL = jnp.float32
247
+ jax.config.update('jax_enable_x64', False)
65
248
  self.ONE = jnp.asarray(1, dtype=self.INT)
66
249
  self.JAX_TYPES = {
67
250
  'int': self.INT,
@@ -70,17 +253,16 @@ class JaxRDDLCompiler:
70
253
  }
71
254
 
72
255
  # compile initial values
73
- if self.logger is not None:
74
- self.logger.clear()
75
- initializer = RDDLValueInitializer(rddl, logger=self.logger)
256
+ initializer = RDDLValueInitializer(rddl)
76
257
  self.init_values = initializer.initialize()
77
258
 
78
259
  # compute dependency graph for CPFs and sort them by evaluation order
79
- sorter = RDDLLevelAnalysis(rddl, allow_synchronous_state, logger=self.logger)
260
+ sorter = RDDLLevelAnalysis(
261
+ rddl, allow_synchronous_state=allow_synchronous_state)
80
262
  self.levels = sorter.compute_levels()
81
263
 
82
264
  # trace expressions to cache information to be used later
83
- tracer = RDDLObjectsTracer(rddl, logger=self.logger, cpf_levels=self.levels)
265
+ tracer = RDDLObjectsTracer(rddl, cpf_levels=self.levels)
84
266
  self.traced = tracer.trace()
85
267
 
86
268
  # extract the box constraints on actions
@@ -92,92 +274,42 @@ class JaxRDDLCompiler:
92
274
  constraints = RDDLConstraints(simulator, vectorized=True)
93
275
  self.constraints = constraints
94
276
 
95
- # basic operations
96
- self.NEGATIVE = lambda x, param: jnp.negative(x)
97
- self.ARITHMETIC_OPS = {
98
- '+': lambda x, y, param: jnp.add(x, y),
99
- '-': lambda x, y, param: jnp.subtract(x, y),
100
- '*': lambda x, y, param: jnp.multiply(x, y),
101
- '/': lambda x, y, param: jnp.divide(x, y)
102
- }
103
- self.RELATIONAL_OPS = {
104
- '>=': lambda x, y, param: jnp.greater_equal(x, y),
105
- '<=': lambda x, y, param: jnp.less_equal(x, y),
106
- '<': lambda x, y, param: jnp.less(x, y),
107
- '>': lambda x, y, param: jnp.greater(x, y),
108
- '==': lambda x, y, param: jnp.equal(x, y),
109
- '~=': lambda x, y, param: jnp.not_equal(x, y)
110
- }
111
- self.LOGICAL_NOT = lambda x, param: jnp.logical_not(x)
112
- self.LOGICAL_OPS = {
113
- '^': lambda x, y, param: jnp.logical_and(x, y),
114
- '&': lambda x, y, param: jnp.logical_and(x, y),
115
- '|': lambda x, y, param: jnp.logical_or(x, y),
116
- '~': lambda x, y, param: jnp.logical_xor(x, y),
117
- '=>': lambda x, y, param: jnp.logical_or(jnp.logical_not(x), y),
118
- '<=>': lambda x, y, param: jnp.equal(x, y)
119
- }
120
- self.AGGREGATION_OPS = {
121
- 'sum': lambda x, axis, param: jnp.sum(x, axis=axis),
122
- 'avg': lambda x, axis, param: jnp.mean(x, axis=axis),
123
- 'prod': lambda x, axis, param: jnp.prod(x, axis=axis),
124
- 'minimum': lambda x, axis, param: jnp.min(x, axis=axis),
125
- 'maximum': lambda x, axis, param: jnp.max(x, axis=axis),
126
- 'forall': lambda x, axis, param: jnp.all(x, axis=axis),
127
- 'exists': lambda x, axis, param: jnp.any(x, axis=axis),
128
- 'argmin': lambda x, axis, param: jnp.argmin(x, axis=axis),
129
- 'argmax': lambda x, axis, param: jnp.argmax(x, axis=axis)
130
- }
277
+ # basic operations - these can be override in subclasses
278
+ self.compile_non_fluent_exact = compile_non_fluent_exact
279
+ self.NEGATIVE = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_NEGATIVE
280
+ self.ARITHMETIC_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_ARITHMETIC.copy()
281
+ self.RELATIONAL_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_RELATIONAL.copy()
282
+ self.LOGICAL_NOT = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL_NOT
283
+ self.LOGICAL_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL.copy()
284
+ self.AGGREGATION_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_AGGREGATION.copy()
131
285
  self.AGGREGATION_BOOL = {'forall', 'exists'}
132
- self.KNOWN_UNARY = {
133
- 'abs': lambda x, param: jnp.abs(x),
134
- 'sgn': lambda x, param: jnp.sign(x),
135
- 'round': lambda x, param: jnp.round(x),
136
- 'floor': lambda x, param: jnp.floor(x),
137
- 'ceil': lambda x, param: jnp.ceil(x),
138
- 'cos': lambda x, param: jnp.cos(x),
139
- 'sin': lambda x, param: jnp.sin(x),
140
- 'tan': lambda x, param: jnp.tan(x),
141
- 'acos': lambda x, param: jnp.arccos(x),
142
- 'asin': lambda x, param: jnp.arcsin(x),
143
- 'atan': lambda x, param: jnp.arctan(x),
144
- 'cosh': lambda x, param: jnp.cosh(x),
145
- 'sinh': lambda x, param: jnp.sinh(x),
146
- 'tanh': lambda x, param: jnp.tanh(x),
147
- 'exp': lambda x, param: jnp.exp(x),
148
- 'ln': lambda x, param: jnp.log(x),
149
- 'sqrt': lambda x, param: jnp.sqrt(x),
150
- 'lngamma': lambda x, param: scipy.special.gammaln(x),
151
- 'gamma': lambda x, param: jnp.exp(scipy.special.gammaln(x))
152
- }
153
- self.KNOWN_BINARY = {
154
- 'div': lambda x, y, param: jnp.floor_divide(x, y),
155
- 'mod': lambda x, y, param: jnp.mod(x, y),
156
- 'fmod': lambda x, y, param: jnp.mod(x, y),
157
- 'min': lambda x, y, param: jnp.minimum(x, y),
158
- 'max': lambda x, y, param: jnp.maximum(x, y),
159
- 'pow': lambda x, y, param: jnp.power(x, y),
160
- 'log': lambda x, y, param: jnp.log(x) / jnp.log(y),
161
- 'hypot': lambda x, y, param: jnp.hypot(x, y)
162
- }
163
-
286
+ self.KNOWN_UNARY = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_UNARY.copy()
287
+ self.KNOWN_BINARY = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BINARY.copy()
288
+ self.IF_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_IF
289
+ self.SWITCH_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
290
+ self.BERNOULLI_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BERNOULLI
291
+ self.DISCRETE_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
292
+
164
293
  # ===========================================================================
165
294
  # main compilation subroutines
166
295
  # ===========================================================================
167
296
 
168
- def compile(self, log_jax_expr: bool=False) -> None:
297
+ def compile(self, log_jax_expr: bool=False, heading: str='') -> None:
169
298
  '''Compiles the current RDDL into Jax expressions.
170
299
 
171
300
  :param log_jax_expr: whether to pretty-print the compiled Jax functions
172
301
  to the log file
302
+ :param heading: the heading to print before compilation information
173
303
  '''
174
- info = {}
304
+ info = ({}, [])
175
305
  self.invariants = self._compile_constraints(self.rddl.invariants, info)
176
306
  self.preconditions = self._compile_constraints(self.rddl.preconditions, info)
177
307
  self.terminations = self._compile_constraints(self.rddl.terminations, info)
178
308
  self.cpfs = self._compile_cpfs(info)
179
309
  self.reward = self._compile_reward(info)
180
- self.model_params = info
310
+ self.model_params = {key: value
311
+ for (key, (value, *_)) in info[0].items()}
312
+ self.relaxations = info[1]
181
313
 
182
314
  if log_jax_expr and self.logger is not None:
183
315
  printed = self.print_jax()
@@ -189,6 +321,7 @@ class JaxRDDLCompiler:
189
321
  printed_terminals = '\n\n'.join(v for v in printed['terminations'])
190
322
  printed_params = '\n'.join(f'{k}: {v}' for (k, v) in info.items())
191
323
  message = (
324
+ f'[info] {heading}\n'
192
325
  f'[info] compiled JAX CPFs:\n\n'
193
326
  f'{printed_cpfs}\n\n'
194
327
  f'[info] compiled JAX reward:\n\n'
@@ -281,17 +414,21 @@ class JaxRDDLCompiler:
281
414
  return jax_inequalities, jax_equalities
282
415
 
283
416
  def compile_transition(self, check_constraints: bool=False,
284
- constraint_func: bool=False):
417
+ constraint_func: bool=False) -> Callable:
285
418
  '''Compiles the current RDDL model into a JAX transition function that
286
419
  samples the next state.
287
420
 
288
- The signature of the returned function is (key, actions, subs,
289
- model_params), where:
421
+ The arguments of the returned function is:
290
422
  - key is the PRNG key
291
423
  - actions is the dict of action tensors
292
424
  - subs is the dict of current pvar value tensors
293
425
  - model_params is a dict of parameters for the relaxed model.
294
-
426
+
427
+ The returned value of the function is:
428
+ - subs is the returned next epoch fluent values
429
+ - log includes all the auxiliary information about constraints
430
+ satisfied, errors, etc.
431
+
295
432
  constraint_func provides the option to compile nonlinear constraints:
296
433
 
297
434
  1. f(s, a) ?? g(s, a)
@@ -361,6 +498,10 @@ class JaxRDDLCompiler:
361
498
  reward, key, err = reward_fn(subs, model_params, key)
362
499
  errors |= err
363
500
 
501
+ # calculate fluent values
502
+ fluents = {name: values for (name, values) in subs.items()
503
+ if name not in rddl.non_fluents}
504
+
364
505
  # set the next state to the current state
365
506
  for (state, next_state) in rddl.next_state.items():
366
507
  subs[state] = subs[next_state]
@@ -383,8 +524,7 @@ class JaxRDDLCompiler:
383
524
 
384
525
  # prepare the return value
385
526
  log = {
386
- 'pvar': subs,
387
- 'action': actions,
527
+ 'fluents': fluents,
388
528
  'reward': reward,
389
529
  'error': errors,
390
530
  'precondition': precond_check,
@@ -395,7 +535,7 @@ class JaxRDDLCompiler:
395
535
  log['inequalities'] = inequalities
396
536
  log['equalities'] = equalities
397
537
 
398
- return log
538
+ return subs, log
399
539
 
400
540
  return _jax_wrapped_single_step
401
541
 
@@ -403,18 +543,28 @@ class JaxRDDLCompiler:
403
543
  n_steps: int,
404
544
  n_batch: int,
405
545
  check_constraints: bool=False,
406
- constraint_func: bool=False):
546
+ constraint_func: bool=False) -> Callable:
407
547
  '''Compiles the current RDDL model into a JAX transition function that
408
548
  samples trajectories with a fixed horizon from a policy.
409
549
 
410
- The signature of the policy function is (key, params, hyperparams,
411
- step, states), where:
550
+ The arguments of the returned function is:
551
+ - key is the PRNG key (used by a stochastic policy)
552
+ - policy_params is a pytree of trainable policy weights
553
+ - hyperparams is a pytree of (optional) fixed policy hyper-parameters
554
+ - subs is the dictionary of current fluent tensor values
555
+ - model_params is a dict of model hyperparameters.
556
+
557
+ The returned value of the returned function is:
558
+ - log is the dictionary of all trajectory information, including
559
+ constraints that were satisfied, errors, etc.
560
+
561
+ The arguments of the policy function is:
412
562
  - key is the PRNG key (used by a stochastic policy)
413
563
  - params is a pytree of trainable policy weights
414
564
  - hyperparams is a pytree of (optional) fixed policy hyper-parameters
415
565
  - step is the time index of the decision in the current rollout
416
566
  - states is a dict of tensors for the current observation.
417
-
567
+
418
568
  :param policy: a Jax compiled function for the policy as described above
419
569
  decision epoch, state dict, and an RNG key and returns an action dict
420
570
  :param n_steps: the rollout horizon
@@ -428,27 +578,32 @@ class JaxRDDLCompiler:
428
578
  rddl = self.rddl
429
579
  jax_step_fn = self.compile_transition(check_constraints, constraint_func)
430
580
 
581
+ # for POMDP only observ-fluents are assumed visible to the policy
582
+ if rddl.observ_fluents:
583
+ observed_vars = rddl.observ_fluents
584
+ else:
585
+ observed_vars = rddl.state_fluents
586
+
431
587
  # evaluate the step from the policy
432
588
  def _jax_wrapped_single_step_policy(key, policy_params, hyperparams,
433
589
  step, subs, model_params):
434
590
  states = {var: values
435
591
  for (var, values) in subs.items()
436
- if rddl.variable_types[var] == 'state-fluent'}
592
+ if var in observed_vars}
437
593
  actions = policy(key, policy_params, hyperparams, step, states)
438
594
  key, subkey = random.split(key)
439
- log = jax_step_fn(subkey, actions, subs, model_params)
440
- return log
595
+ subs, log = jax_step_fn(subkey, actions, subs, model_params)
596
+ return subs, log
441
597
 
442
598
  # do a batched step update from the policy
443
599
  def _jax_wrapped_batched_step_policy(carry, step):
444
600
  key, policy_params, hyperparams, subs, model_params = carry
445
601
  key, *subkeys = random.split(key, num=1 + n_batch)
446
602
  keys = jnp.asarray(subkeys)
447
- log = jax.vmap(
603
+ subs, log = jax.vmap(
448
604
  _jax_wrapped_single_step_policy,
449
605
  in_axes=(0, None, None, None, 0, None)
450
606
  )(keys, policy_params, hyperparams, step, subs, model_params)
451
- subs = log['pvar']
452
607
  carry = (key, policy_params, hyperparams, subs, model_params)
453
608
  return carry, log
454
609
 
@@ -467,7 +622,7 @@ class JaxRDDLCompiler:
467
622
  # error checks
468
623
  # ===========================================================================
469
624
 
470
- def print_jax(self) -> Dict[str, object]:
625
+ def print_jax(self) -> Dict[str, Any]:
471
626
  '''Returns a dictionary containing the string representations of all
472
627
  Jax compiled expressions from the RDDL file.
473
628
  '''
@@ -564,7 +719,7 @@ class JaxRDDLCompiler:
564
719
  }
565
720
 
566
721
  @staticmethod
567
- def get_error_codes(error):
722
+ def get_error_codes(error: int) -> List[int]:
568
723
  '''Given a compacted integer error flag from the execution of Jax, and
569
724
  decomposes it into individual error codes.
570
725
  '''
@@ -573,7 +728,7 @@ class JaxRDDLCompiler:
573
728
  return errors
574
729
 
575
730
  @staticmethod
576
- def get_error_messages(error):
731
+ def get_error_messages(error: int) -> List[str]:
577
732
  '''Given a compacted integer error flag from the execution of Jax, and
578
733
  decomposes it into error strings.
579
734
  '''
@@ -586,28 +741,40 @@ class JaxRDDLCompiler:
586
741
  # ===========================================================================
587
742
 
588
743
  def _unwrap(self, op, expr_id, info):
589
- sep = JaxRDDLCompiler.MODEL_PARAM_TAG_SEPARATOR
590
744
  jax_op, name = op, None
745
+ model_params, relaxed_list = info
591
746
  if isinstance(op, tuple):
592
747
  jax_op, param = op
593
748
  if param is not None:
594
749
  tags, values = param
750
+ sep = JaxRDDLCompiler.MODEL_PARAM_TAG_SEPARATOR
595
751
  if isinstance(tags, tuple):
596
752
  name = sep.join(tags)
597
753
  else:
598
754
  name = str(tags)
599
755
  name = f'{name}{sep}{expr_id}'
600
- if name in info:
601
- raise Exception(f'Model parameter {name} is already defined.')
602
- info[name] = values
756
+ if name in model_params:
757
+ raise RuntimeError(
758
+ f'Internal error: model parameter {name} is already defined.')
759
+ model_params[name] = (values, tags, expr_id, jax_op.__name__)
760
+ relaxed_list.append((param, expr_id, jax_op.__name__))
603
761
  return jax_op, name
604
762
 
605
- def get_ids_of_parameterized_expressions(self) -> List[int]:
606
- '''Returns a list of expression IDs that have tuning parameters.'''
607
- sep = JaxRDDLCompiler.MODEL_PARAM_TAG_SEPARATOR
608
- ids = [int(key.split(sep)[-1]) for key in self.model_params]
609
- return ids
610
-
763
+ def summarize_model_relaxations(self) -> str:
764
+ '''Returns a string of information about model relaxations in the
765
+ compiled model.'''
766
+ occurence_by_type = {}
767
+ for (_, expr_id, jax_op) in self.relaxations:
768
+ etype = self.traced.lookup(expr_id).etype
769
+ source = f'{etype[1]} ({etype[0]})'
770
+ sub = f'{source:<30} --> {jax_op}'
771
+ occurence_by_type[sub] = occurence_by_type.get(sub, 0) + 1
772
+ col = "{:<80} {:<10}\n"
773
+ table = col.format('Substitution', 'Count')
774
+ for (sub, occurs) in occurence_by_type.items():
775
+ table += col.format(sub, occurs)
776
+ return table
777
+
611
778
  # ===========================================================================
612
779
  # expression compilation
613
780
  # ===========================================================================
@@ -640,7 +807,8 @@ class JaxRDDLCompiler:
640
807
  raise RDDLNotImplementedError(
641
808
  f'Internal error: expression type {expr} is not supported.\n' +
642
809
  print_stack_trace(expr))
643
-
810
+
811
+ # force type cast of tensor as required by caller
644
812
  if dtype is not None:
645
813
  jax_expr = self._jax_cast(jax_expr, dtype)
646
814
 
@@ -660,6 +828,17 @@ class JaxRDDLCompiler:
660
828
 
661
829
  return _jax_wrapped_cast
662
830
 
831
+ def _fix_dtype(self, value):
832
+ dtype = jnp.atleast_1d(value).dtype
833
+ if jnp.issubdtype(dtype, jnp.integer):
834
+ return self.INT
835
+ elif jnp.issubdtype(dtype, jnp.floating):
836
+ return self.REAL
837
+ elif jnp.issubdtype(dtype, jnp.bool_) or jnp.issubdtype(dtype, bool):
838
+ return bool
839
+ else:
840
+ raise TypeError(f'Invalid type {dtype} of {value}.')
841
+
663
842
  # ===========================================================================
664
843
  # leaves
665
844
  # ===========================================================================
@@ -669,7 +848,7 @@ class JaxRDDLCompiler:
669
848
  cached_value = self.traced.cached_sim_info(expr)
670
849
 
671
850
  def _jax_wrapped_constant(x, params, key):
672
- sample = jnp.asarray(cached_value)
851
+ sample = jnp.asarray(cached_value, dtype=self._fix_dtype(cached_value))
673
852
  return sample, key, NORMAL
674
853
 
675
854
  return _jax_wrapped_constant
@@ -693,7 +872,7 @@ class JaxRDDLCompiler:
693
872
  cached_value = cached_info
694
873
 
695
874
  def _jax_wrapped_object(x, params, key):
696
- sample = jnp.asarray(cached_value)
875
+ sample = jnp.asarray(cached_value, dtype=self._fix_dtype(cached_value))
697
876
  return sample, key, NORMAL
698
877
 
699
878
  return _jax_wrapped_object
@@ -702,7 +881,8 @@ class JaxRDDLCompiler:
702
881
  elif cached_info is None:
703
882
 
704
883
  def _jax_wrapped_pvar_scalar(x, params, key):
705
- sample = jnp.asarray(x[var])
884
+ value = x[var]
885
+ sample = jnp.asarray(value, dtype=self._fix_dtype(value))
706
886
  return sample, key, NORMAL
707
887
 
708
888
  return _jax_wrapped_pvar_scalar
@@ -721,7 +901,8 @@ class JaxRDDLCompiler:
721
901
 
722
902
  def _jax_wrapped_pvar_tensor_nested(x, params, key):
723
903
  error = NORMAL
724
- sample = jnp.asarray(x[var])
904
+ value = x[var]
905
+ sample = jnp.asarray(value, dtype=self._fix_dtype(value))
725
906
  new_slices = [None] * len(jax_nested_expr)
726
907
  for (i, jax_expr) in enumerate(jax_nested_expr):
727
908
  new_slices[i], key, err = jax_expr(x, params, key)
@@ -736,7 +917,8 @@ class JaxRDDLCompiler:
736
917
  else:
737
918
 
738
919
  def _jax_wrapped_pvar_tensor_non_nested(x, params, key):
739
- sample = jnp.asarray(x[var])
920
+ value = x[var]
921
+ sample = jnp.asarray(value, dtype=self._fix_dtype(value))
740
922
  if slices:
741
923
  sample = sample[slices]
742
924
  if axis:
@@ -795,16 +977,23 @@ class JaxRDDLCompiler:
795
977
 
796
978
  def _jax_arithmetic(self, expr, info):
797
979
  _, op = expr.etype
798
- valid_ops = self.ARITHMETIC_OPS
980
+
981
+ # if expression is non-fluent, always use the exact operation
982
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
983
+ valid_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_ARITHMETIC
984
+ negative_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_NEGATIVE
985
+ else:
986
+ valid_ops = self.ARITHMETIC_OPS
987
+ negative_op = self.NEGATIVE
799
988
  JaxRDDLCompiler._check_valid_op(expr, valid_ops)
800
-
989
+
990
+ # recursively compile arguments
801
991
  args = expr.args
802
992
  n = len(args)
803
-
804
993
  if n == 1 and op == '-':
805
994
  arg, = args
806
995
  jax_expr = self._jax(arg, info)
807
- jax_op, jax_param = self._unwrap(self.NEGATIVE, expr.id, info)
996
+ jax_op, jax_param = self._unwrap(negative_op, expr.id, info)
808
997
  return self._jax_unary(jax_expr, jax_op, jax_param, at_least_int=True)
809
998
 
810
999
  elif n == 2:
@@ -819,29 +1008,42 @@ class JaxRDDLCompiler:
819
1008
 
820
1009
  def _jax_relational(self, expr, info):
821
1010
  _, op = expr.etype
822
- valid_ops = self.RELATIONAL_OPS
1011
+
1012
+ # if expression is non-fluent, always use the exact operation
1013
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
1014
+ valid_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_RELATIONAL
1015
+ else:
1016
+ valid_ops = self.RELATIONAL_OPS
823
1017
  JaxRDDLCompiler._check_valid_op(expr, valid_ops)
824
- JaxRDDLCompiler._check_num_args(expr, 2)
1018
+ jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
825
1019
 
1020
+ # recursively compile arguments
1021
+ JaxRDDLCompiler._check_num_args(expr, 2)
826
1022
  lhs, rhs = expr.args
827
1023
  jax_lhs = self._jax(lhs, info)
828
1024
  jax_rhs = self._jax(rhs, info)
829
- jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
830
1025
  return self._jax_binary(
831
1026
  jax_lhs, jax_rhs, jax_op, jax_param, at_least_int=True)
832
1027
 
833
1028
  def _jax_logical(self, expr, info):
834
1029
  _, op = expr.etype
835
- valid_ops = self.LOGICAL_OPS
1030
+
1031
+ # if expression is non-fluent, always use the exact operation
1032
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
1033
+ valid_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL
1034
+ logical_not_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL_NOT
1035
+ else:
1036
+ valid_ops = self.LOGICAL_OPS
1037
+ logical_not_op = self.LOGICAL_NOT
836
1038
  JaxRDDLCompiler._check_valid_op(expr, valid_ops)
837
1039
 
1040
+ # recursively compile arguments
838
1041
  args = expr.args
839
- n = len(args)
840
-
1042
+ n = len(args)
841
1043
  if n == 1 and op == '~':
842
1044
  arg, = args
843
1045
  jax_expr = self._jax(arg, info)
844
- jax_op, jax_param = self._unwrap(self.LOGICAL_NOT, expr.id, info)
1046
+ jax_op, jax_param = self._unwrap(logical_not_op, expr.id, info)
845
1047
  return self._jax_unary(jax_expr, jax_op, jax_param, check_dtype=bool)
846
1048
 
847
1049
  elif n == 2:
@@ -856,17 +1058,21 @@ class JaxRDDLCompiler:
856
1058
 
857
1059
  def _jax_aggregation(self, expr, info):
858
1060
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
859
-
860
1061
  _, op = expr.etype
861
- valid_ops = self.AGGREGATION_OPS
1062
+
1063
+ # if expression is non-fluent, always use the exact operation
1064
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
1065
+ valid_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_AGGREGATION
1066
+ else:
1067
+ valid_ops = self.AGGREGATION_OPS
862
1068
  JaxRDDLCompiler._check_valid_op(expr, valid_ops)
863
- is_floating = op not in self.AGGREGATION_BOOL
1069
+ jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
864
1070
 
1071
+ # recursively compile arguments
1072
+ is_floating = op not in self.AGGREGATION_BOOL
865
1073
  * _, arg = expr.args
866
- _, axes = self.traced.cached_sim_info(expr)
867
-
1074
+ _, axes = self.traced.cached_sim_info(expr)
868
1075
  jax_expr = self._jax(arg, info)
869
- jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
870
1076
 
871
1077
  def _jax_wrapped_aggregation(x, params, key):
872
1078
  sample, key, err = jax_expr(x, params, key)
@@ -884,21 +1090,28 @@ class JaxRDDLCompiler:
884
1090
  def _jax_functional(self, expr, info):
885
1091
  _, op = expr.etype
886
1092
 
887
- # unary function
888
- if op in self.KNOWN_UNARY:
1093
+ # if expression is non-fluent, always use the exact operation
1094
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
1095
+ unary_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_UNARY
1096
+ binary_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BINARY
1097
+ else:
1098
+ unary_ops = self.KNOWN_UNARY
1099
+ binary_ops = self.KNOWN_BINARY
1100
+
1101
+ # recursively compile arguments
1102
+ if op in unary_ops:
889
1103
  JaxRDDLCompiler._check_num_args(expr, 1)
890
1104
  arg, = expr.args
891
1105
  jax_expr = self._jax(arg, info)
892
- jax_op, jax_param = self._unwrap(self.KNOWN_UNARY[op], expr.id, info)
1106
+ jax_op, jax_param = self._unwrap(unary_ops[op], expr.id, info)
893
1107
  return self._jax_unary(jax_expr, jax_op, jax_param, at_least_int=True)
894
1108
 
895
- # binary function
896
- elif op in self.KNOWN_BINARY:
1109
+ elif op in binary_ops:
897
1110
  JaxRDDLCompiler._check_num_args(expr, 2)
898
1111
  lhs, rhs = expr.args
899
1112
  jax_lhs = self._jax(lhs, info)
900
1113
  jax_rhs = self._jax(rhs, info)
901
- jax_op, jax_param = self._unwrap(self.KNOWN_BINARY[op], expr.id, info)
1114
+ jax_op, jax_param = self._unwrap(binary_ops[op], expr.id, info)
902
1115
  return self._jax_binary(
903
1116
  jax_lhs, jax_rhs, jax_op, jax_param, at_least_int=True)
904
1117
 
@@ -921,19 +1134,19 @@ class JaxRDDLCompiler:
921
1134
  f'Control operator {op} is not supported.\n' +
922
1135
  print_stack_trace(expr))
923
1136
 
924
- def _jax_if_helper(self):
925
-
926
- def _jax_wrapped_if_calc_exact(c, a, b, param):
927
- return jnp.where(c, a, b)
928
-
929
- return _jax_wrapped_if_calc_exact
930
-
931
1137
  def _jax_if(self, expr, info):
932
1138
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
933
1139
  JaxRDDLCompiler._check_num_args(expr, 3)
934
- jax_if, jax_param = self._unwrap(self._jax_if_helper(), expr.id, info)
1140
+ pred, if_true, if_false = expr.args
1141
+
1142
+ # if predicate is non-fluent, always use the exact operation
1143
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(pred):
1144
+ if_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_IF
1145
+ else:
1146
+ if_op = self.IF_HELPER
1147
+ jax_if, jax_param = self._unwrap(if_op, expr.id, info)
935
1148
 
936
- pred, if_true, if_false = expr.args
1149
+ # recursively compile arguments
937
1150
  jax_pred = self._jax(pred, info)
938
1151
  jax_true = self._jax(if_true, info)
939
1152
  jax_false = self._jax(if_false, info)
@@ -951,23 +1164,20 @@ class JaxRDDLCompiler:
951
1164
 
952
1165
  return _jax_wrapped_if_then_else
953
1166
 
954
- def _jax_switch_helper(self):
955
-
956
- def _jax_wrapped_switch_calc_exact(pred, cases, param):
957
- pred = pred[jnp.newaxis, ...]
958
- sample = jnp.take_along_axis(cases, pred, axis=0)
959
- assert sample.shape[0] == 1
960
- return sample[0, ...]
961
-
962
- return _jax_wrapped_switch_calc_exact
963
-
964
1167
  def _jax_switch(self, expr, info):
965
- pred, *_ = expr.args
1168
+
1169
+ # if expression is non-fluent, always use the exact operation
1170
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
1171
+ switch_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
1172
+ else:
1173
+ switch_op = self.SWITCH_HELPER
1174
+ jax_switch, jax_param = self._unwrap(switch_op, expr.id, info)
1175
+
1176
+ # recursively compile predicate
1177
+ pred, *_ = expr.args
966
1178
  jax_pred = self._jax(pred, info)
967
- jax_switch, jax_param = self._unwrap(
968
- self._jax_switch_helper(), expr.id, info)
969
1179
 
970
- # wrap cases as JAX expressions
1180
+ # recursively compile cases
971
1181
  cases, default = self.traced.cached_sim_info(expr)
972
1182
  jax_default = None if default is None else self._jax(default, info)
973
1183
  jax_cases = [(jax_default if _case is None else self._jax(_case, info))
@@ -983,7 +1193,8 @@ class JaxRDDLCompiler:
983
1193
  for (i, jax_case) in enumerate(jax_cases):
984
1194
  sample_cases[i], key, err_case = jax_case(x, params, key)
985
1195
  err |= err_case
986
- sample_cases = jnp.asarray(sample_cases)
1196
+ sample_cases = jnp.asarray(
1197
+ sample_cases, dtype=self._fix_dtype(sample_cases))
987
1198
 
988
1199
  # predicate (enum) is an integer - use it to extract from case array
989
1200
  param = params.get(jax_param, None)
@@ -1179,30 +1390,28 @@ class JaxRDDLCompiler:
1179
1390
  scale, key, err2 = jax_scale(x, params, key)
1180
1391
  key, subkey = random.split(key)
1181
1392
  U = random.uniform(key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
1182
- sample = scale * jnp.power(-jnp.log1p(-U), 1.0 / shape)
1393
+ sample = scale * jnp.power(-jnp.log(U), 1.0 / shape)
1183
1394
  out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
1184
1395
  err = err1 | err2 | (out_of_bounds * ERR)
1185
1396
  return sample, key, err
1186
1397
 
1187
1398
  return _jax_wrapped_distribution_weibull
1188
1399
 
1189
- def _jax_bernoulli_helper(self):
1190
-
1191
- def _jax_wrapped_calc_bernoulli_exact(key, prob, param):
1192
- return random.bernoulli(key, prob)
1193
-
1194
- return _jax_wrapped_calc_bernoulli_exact
1195
-
1196
1400
  def _jax_bernoulli(self, expr, info):
1197
1401
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_BERNOULLI']
1198
1402
  JaxRDDLCompiler._check_num_args(expr, 1)
1199
- jax_bern, jax_param = self._unwrap(
1200
- self._jax_bernoulli_helper(), expr.id, info)
1201
-
1202
1403
  arg_prob, = expr.args
1404
+
1405
+ # if probability is non-fluent, always use the exact operation
1406
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_prob):
1407
+ bern_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BERNOULLI
1408
+ else:
1409
+ bern_op = self.BERNOULLI_HELPER
1410
+ jax_bern, jax_param = self._unwrap(bern_op, expr.id, info)
1411
+
1412
+ # recursively compile arguments
1203
1413
  jax_prob = self._jax(arg_prob, info)
1204
1414
 
1205
- # uses the implicit JAX subroutine
1206
1415
  def _jax_wrapped_distribution_bernoulli(x, params, key):
1207
1416
  prob, key, err = jax_prob(x, params, key)
1208
1417
  key, subkey = random.split(key)
@@ -1266,8 +1475,8 @@ class JaxRDDLCompiler:
1266
1475
  def _jax_wrapped_distribution_binomial(x, params, key):
1267
1476
  trials, key, err2 = jax_trials(x, params, key)
1268
1477
  prob, key, err1 = jax_prob(x, params, key)
1269
- trials = jnp.asarray(trials, self.REAL)
1270
- prob = jnp.asarray(prob, self.REAL)
1478
+ trials = jnp.asarray(trials, dtype=self.REAL)
1479
+ prob = jnp.asarray(prob, dtype=self.REAL)
1271
1480
  key, subkey = random.split(key)
1272
1481
  dist = tfp.distributions.Binomial(total_count=trials, probs=prob)
1273
1482
  sample = dist.sample(seed=subkey).astype(self.INT)
@@ -1290,11 +1499,10 @@ class JaxRDDLCompiler:
1290
1499
  def _jax_wrapped_distribution_negative_binomial(x, params, key):
1291
1500
  trials, key, err2 = jax_trials(x, params, key)
1292
1501
  prob, key, err1 = jax_prob(x, params, key)
1293
- trials = jnp.asarray(trials, self.REAL)
1294
- prob = jnp.asarray(prob, self.REAL)
1502
+ trials = jnp.asarray(trials, dtype=self.REAL)
1503
+ prob = jnp.asarray(prob, dtype=self.REAL)
1295
1504
  key, subkey = random.split(key)
1296
- dist = tfp.distributions.NegativeBinomial(
1297
- total_count=trials, probs=prob)
1505
+ dist = tfp.distributions.NegativeBinomial(total_count=trials, probs=prob)
1298
1506
  sample = dist.sample(seed=subkey).astype(self.INT)
1299
1507
  out_of_bounds = jnp.logical_not(jnp.all(
1300
1508
  (prob >= 0) & (prob <= 1) & (trials > 0)))
@@ -1316,7 +1524,7 @@ class JaxRDDLCompiler:
1316
1524
  shape, key, err1 = jax_shape(x, params, key)
1317
1525
  rate, key, err2 = jax_rate(x, params, key)
1318
1526
  key, subkey = random.split(key)
1319
- sample = random.beta(key=subkey, a=shape, b=rate)
1527
+ sample = random.beta(key=subkey, a=shape, b=rate, dtype=self.REAL)
1320
1528
  out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (rate > 0)))
1321
1529
  err = err1 | err2 | (out_of_bounds * ERR)
1322
1530
  return sample, key, err
@@ -1325,23 +1533,35 @@ class JaxRDDLCompiler:
1325
1533
 
1326
1534
  def _jax_geometric(self, expr, info):
1327
1535
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_GEOMETRIC']
1328
- JaxRDDLCompiler._check_num_args(expr, 1)
1329
-
1536
+ JaxRDDLCompiler._check_num_args(expr, 1)
1330
1537
  arg_prob, = expr.args
1331
1538
  jax_prob = self._jax(arg_prob, info)
1332
- floor_op, jax_param = self._unwrap(
1333
- self.KNOWN_UNARY['floor'], expr.id, info)
1334
1539
 
1335
- # reparameterization trick Geom(p) = floor(ln(U(0, 1)) / ln(p)) + 1
1336
- def _jax_wrapped_distribution_geometric(x, params, key):
1337
- prob, key, err = jax_prob(x, params, key)
1338
- key, subkey = random.split(key)
1339
- U = random.uniform(key=subkey, shape=jnp.shape(prob), dtype=self.REAL)
1340
- param = params.get(jax_param, None)
1341
- sample = floor_op(jnp.log1p(-U) / jnp.log1p(-prob), param) + 1
1342
- out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
1343
- err |= (out_of_bounds * ERR)
1344
- return sample, key, err
1540
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_prob):
1541
+
1542
+ # prob is non-fluent: do not reparameterize
1543
+ def _jax_wrapped_distribution_geometric(x, params, key):
1544
+ prob, key, err = jax_prob(x, params, key)
1545
+ key, subkey = random.split(key)
1546
+ sample = random.geometric(key=subkey, p=prob, dtype=self.INT)
1547
+ out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
1548
+ err |= (out_of_bounds * ERR)
1549
+ return sample, key, err
1550
+
1551
+ else:
1552
+ floor_op, jax_param = self._unwrap(
1553
+ self.KNOWN_UNARY['floor'], expr.id, info)
1554
+
1555
+ # reparameterization trick Geom(p) = floor(ln(U(0, 1)) / ln(p)) + 1
1556
+ def _jax_wrapped_distribution_geometric(x, params, key):
1557
+ prob, key, err = jax_prob(x, params, key)
1558
+ key, subkey = random.split(key)
1559
+ U = random.uniform(key=subkey, shape=jnp.shape(prob), dtype=self.REAL)
1560
+ param = params.get(jax_param, None)
1561
+ sample = floor_op(jnp.log(U) / jnp.log(1.0 - prob), param) + 1
1562
+ out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
1563
+ err |= (out_of_bounds * ERR)
1564
+ return sample, key, err
1345
1565
 
1346
1566
  return _jax_wrapped_distribution_geometric
1347
1567
 
@@ -1359,7 +1579,7 @@ class JaxRDDLCompiler:
1359
1579
  shape, key, err1 = jax_shape(x, params, key)
1360
1580
  scale, key, err2 = jax_scale(x, params, key)
1361
1581
  key, subkey = random.split(key)
1362
- sample = scale * random.pareto(key=subkey, b=shape)
1582
+ sample = scale * random.pareto(key=subkey, b=shape, dtype=self.REAL)
1363
1583
  out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
1364
1584
  err = err1 | err2 | (out_of_bounds * ERR)
1365
1585
  return sample, key, err
@@ -1377,7 +1597,8 @@ class JaxRDDLCompiler:
1377
1597
  def _jax_wrapped_distribution_t(x, params, key):
1378
1598
  df, key, err = jax_df(x, params, key)
1379
1599
  key, subkey = random.split(key)
1380
- sample = random.t(key=subkey, df=df, shape=jnp.shape(df))
1600
+ sample = random.t(
1601
+ key=subkey, df=df, shape=jnp.shape(df), dtype=self.REAL)
1381
1602
  out_of_bounds = jnp.logical_not(jnp.all(df > 0))
1382
1603
  err |= (out_of_bounds * ERR)
1383
1604
  return sample, key, err
@@ -1464,7 +1685,7 @@ class JaxRDDLCompiler:
1464
1685
  scale, key, err2 = jax_scale(x, params, key)
1465
1686
  key, subkey = random.split(key)
1466
1687
  U = random.uniform(key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
1467
- sample = jnp.log(1.0 - jnp.log1p(-U) / shape) / scale
1688
+ sample = jnp.log(1.0 - jnp.log(U) / shape) / scale
1468
1689
  out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
1469
1690
  err = err1 | err2 | (out_of_bounds * ERR)
1470
1691
  return sample, key, err
@@ -1516,25 +1737,21 @@ class JaxRDDLCompiler:
1516
1737
  # random variables with enum support
1517
1738
  # ===========================================================================
1518
1739
 
1519
- def _jax_discrete_helper(self):
1520
-
1521
- def _jax_wrapped_discrete_calc_exact(key, prob, param):
1522
- logits = jnp.log(prob)
1523
- sample = random.categorical(key=key, logits=logits, axis=-1)
1524
- out_of_bounds = jnp.logical_not(jnp.logical_and(
1525
- jnp.all(prob >= 0),
1526
- jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
1527
- return sample, out_of_bounds
1528
-
1529
- return _jax_wrapped_discrete_calc_exact
1530
-
1531
1740
  def _jax_discrete(self, expr, info, unnorm):
1532
1741
  NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
1533
1742
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_DISCRETE']
1534
- jax_discrete, jax_param = self._unwrap(
1535
- self._jax_discrete_helper(), expr.id, info)
1536
-
1537
1743
  ordered_args = self.traced.cached_sim_info(expr)
1744
+
1745
+ # if all probabilities are non-fluent, then always sample exact
1746
+ has_fluent_arg = any(self.traced.cached_is_fluent(arg)
1747
+ for arg in ordered_args)
1748
+ if self.compile_non_fluent_exact and not has_fluent_arg:
1749
+ discrete_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
1750
+ else:
1751
+ discrete_op = self.DISCRETE_HELPER
1752
+ jax_discrete, jax_param = self._unwrap(discrete_op, expr.id, info)
1753
+
1754
+ # compile probability expressions
1538
1755
  jax_probs = [self._jax(arg, info) for arg in ordered_args]
1539
1756
 
1540
1757
  def _jax_wrapped_distribution_discrete(x, params, key):
@@ -1561,12 +1778,18 @@ class JaxRDDLCompiler:
1561
1778
 
1562
1779
  def _jax_discrete_pvar(self, expr, info, unnorm):
1563
1780
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_DISCRETE']
1564
- JaxRDDLCompiler._check_num_args(expr, 1)
1565
- jax_discrete, jax_param = self._unwrap(
1566
- self._jax_discrete_helper(), expr.id, info)
1567
-
1781
+ JaxRDDLCompiler._check_num_args(expr, 2)
1568
1782
  _, args = expr.args
1569
1783
  arg, = args
1784
+
1785
+ # if probabilities are non-fluent, then always sample exact
1786
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg):
1787
+ discrete_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
1788
+ else:
1789
+ discrete_op = self.DISCRETE_HELPER
1790
+ jax_discrete, jax_param = self._unwrap(discrete_op, expr.id, info)
1791
+
1792
+ # compile probability function
1570
1793
  jax_probs = self._jax(arg, info)
1571
1794
 
1572
1795
  def _jax_wrapped_distribution_discrete_pvar(x, params, key):
@@ -1687,7 +1910,7 @@ class JaxRDDLCompiler:
1687
1910
  out_of_bounds = jnp.logical_not(jnp.all(alpha > 0))
1688
1911
  error |= (out_of_bounds * ERR)
1689
1912
  key, subkey = random.split(key)
1690
- Gamma = random.gamma(key=subkey, a=alpha)
1913
+ Gamma = random.gamma(key=subkey, a=alpha, dtype=self.REAL)
1691
1914
  sample = Gamma / jnp.sum(Gamma, axis=-1, keepdims=True)
1692
1915
  sample = jnp.moveaxis(sample, source=-1, destination=index)
1693
1916
  return sample, key, error
@@ -1706,8 +1929,8 @@ class JaxRDDLCompiler:
1706
1929
  def _jax_wrapped_distribution_multinomial(x, params, key):
1707
1930
  trials, key, err1 = jax_trials(x, params, key)
1708
1931
  prob, key, err2 = jax_prob(x, params, key)
1709
- trials = jnp.asarray(trials, self.REAL)
1710
- prob = jnp.asarray(prob, self.REAL)
1932
+ trials = jnp.asarray(trials, dtype=self.REAL)
1933
+ prob = jnp.asarray(prob, dtype=self.REAL)
1711
1934
  key, subkey = random.split(key)
1712
1935
  dist = tfp.distributions.Multinomial(total_count=trials, probs=prob)
1713
1936
  sample = dist.sample(seed=subkey).astype(self.INT)