pyRDDLGym-jax 0.1__py3-none-any.whl → 0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. pyRDDLGym_jax/core/compiler.py +445 -221
  2. pyRDDLGym_jax/core/logic.py +129 -62
  3. pyRDDLGym_jax/core/planner.py +699 -332
  4. pyRDDLGym_jax/core/simulator.py +5 -7
  5. pyRDDLGym_jax/core/tuning.py +23 -12
  6. pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_drp.cfg → Cartpole_Continuous_gym_drp.cfg} +2 -3
  7. pyRDDLGym_jax/examples/configs/{HVAC_drp.cfg → HVAC_ippc2023_drp.cfg} +2 -2
  8. pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +19 -0
  9. pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +18 -0
  10. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +18 -0
  11. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +1 -1
  12. pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +1 -1
  13. pyRDDLGym_jax/examples/configs/default_drp.cfg +19 -0
  14. pyRDDLGym_jax/examples/configs/default_replan.cfg +20 -0
  15. pyRDDLGym_jax/examples/configs/default_slp.cfg +19 -0
  16. pyRDDLGym_jax/examples/run_gradient.py +1 -1
  17. pyRDDLGym_jax/examples/run_gym.py +1 -2
  18. pyRDDLGym_jax/examples/run_plan.py +7 -0
  19. pyRDDLGym_jax/examples/run_tune.py +6 -0
  20. {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.2.dist-info}/METADATA +1 -1
  21. pyRDDLGym_jax-0.2.dist-info/RECORD +46 -0
  22. {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.2.dist-info}/WHEEL +1 -1
  23. pyRDDLGym_jax-0.1.dist-info/RECORD +0 -40
  24. /pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_replan.cfg → Cartpole_Continuous_gym_replan.cfg} +0 -0
  25. /pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_slp.cfg → Cartpole_Continuous_gym_slp.cfg} +0 -0
  26. /pyRDDLGym_jax/examples/configs/{HVAC_slp.cfg → HVAC_ippc2023_slp.cfg} +0 -0
  27. /pyRDDLGym_jax/examples/configs/{MarsRover_drp.cfg → MarsRover_ippc2023_drp.cfg} +0 -0
  28. /pyRDDLGym_jax/examples/configs/{MarsRover_slp.cfg → MarsRover_ippc2023_slp.cfg} +0 -0
  29. /pyRDDLGym_jax/examples/configs/{MountainCar_slp.cfg → MountainCar_Continuous_gym_slp.cfg} +0 -0
  30. /pyRDDLGym_jax/examples/configs/{Pendulum_slp.cfg → Pendulum_gym_slp.cfg} +0 -0
  31. /pyRDDLGym_jax/examples/configs/{PowerGen_drp.cfg → PowerGen_Continuous_drp.cfg} +0 -0
  32. /pyRDDLGym_jax/examples/configs/{PowerGen_replan.cfg → PowerGen_Continuous_replan.cfg} +0 -0
  33. /pyRDDLGym_jax/examples/configs/{PowerGen_slp.cfg → PowerGen_Continuous_slp.cfg} +0 -0
  34. {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.2.dist-info}/LICENSE +0 -0
  35. {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.2.dist-info}/top_level.txt +0 -0
@@ -4,7 +4,7 @@ import jax.numpy as jnp
4
4
  import jax.random as random
5
5
  import jax.scipy as scipy
6
6
  import traceback
7
- from typing import Callable, Dict, List
7
+ from typing import Any, Callable, Dict, List, Optional
8
8
 
9
9
  from pyRDDLGym.core.debug.exception import raise_warning
10
10
 
@@ -13,8 +13,9 @@ try:
13
13
  from tensorflow_probability.substrates import jax as tfp
14
14
  except Exception:
15
15
  raise_warning('Failed to import tensorflow-probability: '
16
- 'compilation of some complex distributions will not work.',
17
- 'red')
16
+ 'compilation of some complex distributions '
17
+ '(Binomial, Negative-Binomial, Multinomial) will fail. '
18
+ 'Please ensure this package is installed correctly.', 'red')
18
19
  traceback.print_exc()
19
20
  tfp = None
20
21
 
@@ -32,6 +33,98 @@ from pyRDDLGym.core.debug.logger import Logger
32
33
  from pyRDDLGym.core.simulator import RDDLSimulatorPrecompiled
33
34
 
34
35
 
36
+ # ===========================================================================
37
+ # EXACT RDDL TO JAX COMPILATION RULES
38
+ # ===========================================================================
39
+
40
+ def _function_unary_exact_named(op, name):
41
+
42
+ def _jax_wrapped_unary_fn_exact(x, param):
43
+ return op(x)
44
+
45
+ return _jax_wrapped_unary_fn_exact
46
+
47
+
48
+ def _function_unary_exact_named_gamma():
49
+
50
+ def _jax_wrapped_unary_gamma_exact(x, param):
51
+ return jnp.exp(scipy.special.gammaln(x))
52
+
53
+ return _jax_wrapped_unary_gamma_exact
54
+
55
+
56
+ def _function_binary_exact_named(op, name):
57
+
58
+ def _jax_wrapped_binary_fn_exact(x, y, param):
59
+ return op(x, y)
60
+
61
+ return _jax_wrapped_binary_fn_exact
62
+
63
+
64
+ def _function_binary_exact_named_implies():
65
+
66
+ def _jax_wrapped_binary_implies_exact(x, y, param):
67
+ return jnp.logical_or(jnp.logical_not(x), y)
68
+
69
+ return _jax_wrapped_binary_implies_exact
70
+
71
+
72
+ def _function_binary_exact_named_log():
73
+
74
+ def _jax_wrapped_binary_log_exact(x, y, param):
75
+ return jnp.log(x) / jnp.log(y)
76
+
77
+ return _jax_wrapped_binary_log_exact
78
+
79
+
80
+ def _function_aggregation_exact_named(op, name):
81
+
82
+ def _jax_wrapped_aggregation_fn_exact(x, axis, param):
83
+ return op(x, axis=axis)
84
+
85
+ return _jax_wrapped_aggregation_fn_exact
86
+
87
+
88
+ def _function_if_exact_named():
89
+
90
+ def _jax_wrapped_if_exact(c, a, b, param):
91
+ return jnp.where(c, a, b)
92
+
93
+ return _jax_wrapped_if_exact
94
+
95
+
96
+ def _function_switch_exact_named():
97
+
98
+ def _jax_wrapped_switch_exact(pred, cases, param):
99
+ pred = pred[jnp.newaxis, ...]
100
+ sample = jnp.take_along_axis(cases, pred, axis=0)
101
+ assert sample.shape[0] == 1
102
+ return sample[0, ...]
103
+
104
+ return _jax_wrapped_switch_exact
105
+
106
+
107
+ def _function_bernoulli_exact_named():
108
+
109
+ def _jax_wrapped_bernoulli_exact(key, prob, param):
110
+ return random.bernoulli(key, prob)
111
+
112
+ return _jax_wrapped_bernoulli_exact
113
+
114
+
115
+ def _function_discrete_exact_named():
116
+
117
+ def _jax_wrapped_discrete_exact(key, prob, param):
118
+ logits = jnp.log(prob)
119
+ sample = random.categorical(key=key, logits=logits, axis=-1)
120
+ out_of_bounds = jnp.logical_not(jnp.logical_and(
121
+ jnp.all(prob >= 0),
122
+ jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
123
+ return sample, out_of_bounds
124
+
125
+ return _jax_wrapped_discrete_exact
126
+
127
+
35
128
  class JaxRDDLCompiler:
36
129
  '''Compiles a RDDL AST representation into an equivalent JAX representation.
37
130
  All operations are identical to their numpy equivalents.
@@ -39,10 +132,97 @@ class JaxRDDLCompiler:
39
132
 
40
133
  MODEL_PARAM_TAG_SEPARATOR = '___'
41
134
 
135
+ # ===========================================================================
136
+ # EXACT RDDL TO JAX COMPILATION RULES BY DEFAULT
137
+ # ===========================================================================
138
+
139
+ EXACT_RDDL_TO_JAX_NEGATIVE = _function_unary_exact_named(jnp.negative, 'negative')
140
+
141
+ EXACT_RDDL_TO_JAX_ARITHMETIC = {
142
+ '+': _function_binary_exact_named(jnp.add, 'add'),
143
+ '-': _function_binary_exact_named(jnp.subtract, 'subtract'),
144
+ '*': _function_binary_exact_named(jnp.multiply, 'multiply'),
145
+ '/': _function_binary_exact_named(jnp.divide, 'divide')
146
+ }
147
+
148
+ EXACT_RDDL_TO_JAX_RELATIONAL = {
149
+ '>=': _function_binary_exact_named(jnp.greater_equal, 'greater_equal'),
150
+ '<=': _function_binary_exact_named(jnp.less_equal, 'less_equal'),
151
+ '<': _function_binary_exact_named(jnp.less, 'less'),
152
+ '>': _function_binary_exact_named(jnp.greater, 'greater'),
153
+ '==': _function_binary_exact_named(jnp.equal, 'equal'),
154
+ '~=': _function_binary_exact_named(jnp.not_equal, 'not_equal')
155
+ }
156
+
157
+ EXACT_RDDL_TO_JAX_LOGICAL = {
158
+ '^': _function_binary_exact_named(jnp.logical_and, 'and'),
159
+ '&': _function_binary_exact_named(jnp.logical_and, 'and'),
160
+ '|': _function_binary_exact_named(jnp.logical_or, 'or'),
161
+ '~': _function_binary_exact_named(jnp.logical_xor, 'xor'),
162
+ '=>': _function_binary_exact_named_implies(),
163
+ '<=>': _function_binary_exact_named(jnp.equal, 'iff')
164
+ }
165
+
166
+ EXACT_RDDL_TO_JAX_LOGICAL_NOT = _function_unary_exact_named(jnp.logical_not, 'not')
167
+
168
+ EXACT_RDDL_TO_JAX_AGGREGATION = {
169
+ 'sum': _function_aggregation_exact_named(jnp.sum, 'sum'),
170
+ 'avg': _function_aggregation_exact_named(jnp.mean, 'avg'),
171
+ 'prod': _function_aggregation_exact_named(jnp.prod, 'prod'),
172
+ 'minimum': _function_aggregation_exact_named(jnp.min, 'minimum'),
173
+ 'maximum': _function_aggregation_exact_named(jnp.max, 'maximum'),
174
+ 'forall': _function_aggregation_exact_named(jnp.all, 'forall'),
175
+ 'exists': _function_aggregation_exact_named(jnp.any, 'exists'),
176
+ 'argmin': _function_aggregation_exact_named(jnp.argmin, 'argmin'),
177
+ 'argmax': _function_aggregation_exact_named(jnp.argmax, 'argmax')
178
+ }
179
+
180
+ EXACT_RDDL_TO_JAX_UNARY = {
181
+ 'abs': _function_unary_exact_named(jnp.abs, 'abs'),
182
+ 'sgn': _function_unary_exact_named(jnp.sign, 'sgn'),
183
+ 'round': _function_unary_exact_named(jnp.round, 'round'),
184
+ 'floor': _function_unary_exact_named(jnp.floor, 'floor'),
185
+ 'ceil': _function_unary_exact_named(jnp.ceil, 'ceil'),
186
+ 'cos': _function_unary_exact_named(jnp.cos, 'cos'),
187
+ 'sin': _function_unary_exact_named(jnp.sin, 'sin'),
188
+ 'tan': _function_unary_exact_named(jnp.tan, 'tan'),
189
+ 'acos': _function_unary_exact_named(jnp.arccos, 'acos'),
190
+ 'asin': _function_unary_exact_named(jnp.arcsin, 'asin'),
191
+ 'atan': _function_unary_exact_named(jnp.arctan, 'atan'),
192
+ 'cosh': _function_unary_exact_named(jnp.cosh, 'cosh'),
193
+ 'sinh': _function_unary_exact_named(jnp.sinh, 'sinh'),
194
+ 'tanh': _function_unary_exact_named(jnp.tanh, 'tanh'),
195
+ 'exp': _function_unary_exact_named(jnp.exp, 'exp'),
196
+ 'ln': _function_unary_exact_named(jnp.log, 'ln'),
197
+ 'sqrt': _function_unary_exact_named(jnp.sqrt, 'sqrt'),
198
+ 'lngamma': _function_unary_exact_named(scipy.special.gammaln, 'lngamma'),
199
+ 'gamma': _function_unary_exact_named_gamma()
200
+ }
201
+
202
+ EXACT_RDDL_TO_JAX_BINARY = {
203
+ 'div': _function_binary_exact_named(jnp.floor_divide, 'div'),
204
+ 'mod': _function_binary_exact_named(jnp.mod, 'mod'),
205
+ 'fmod': _function_binary_exact_named(jnp.mod, 'fmod'),
206
+ 'min': _function_binary_exact_named(jnp.minimum, 'min'),
207
+ 'max': _function_binary_exact_named(jnp.maximum, 'max'),
208
+ 'pow': _function_binary_exact_named(jnp.power, 'pow'),
209
+ 'log': _function_binary_exact_named_log(),
210
+ 'hypot': _function_binary_exact_named(jnp.hypot, 'hypot'),
211
+ }
212
+
213
+ EXACT_RDDL_TO_JAX_IF = _function_if_exact_named()
214
+
215
+ EXACT_RDDL_TO_JAX_SWITCH = _function_switch_exact_named()
216
+
217
+ EXACT_RDDL_TO_JAX_BERNOULLI = _function_bernoulli_exact_named()
218
+
219
+ EXACT_RDDL_TO_JAX_DISCRETE = _function_discrete_exact_named()
220
+
42
221
  def __init__(self, rddl: RDDLLiftedModel,
43
222
  allow_synchronous_state: bool=True,
44
- logger: Logger=None,
45
- use64bit: bool=False) -> None:
223
+ logger: Optional[Logger]=None,
224
+ use64bit: bool=False,
225
+ compile_non_fluent_exact: bool=True) -> None:
46
226
  '''Creates a new RDDL to Jax compiler.
47
227
 
48
228
  :param rddl: the RDDL model to compile into Jax
@@ -50,11 +230,14 @@ class JaxRDDLCompiler:
50
230
  on each other
51
231
  :param logger: to log information about compilation to file
52
232
  :param use64bit: whether to use 64 bit arithmetic
233
+ :param compile_non_fluent_exact: whether non-fluent expressions
234
+ are always compiled using exact JAX expressions.
53
235
  '''
54
236
  self.rddl = rddl
55
237
  self.logger = logger
56
238
  # jax.config.update('jax_log_compiles', True) # for testing ONLY
57
239
 
240
+ self.use64bit = use64bit
58
241
  if use64bit:
59
242
  self.INT = jnp.int64
60
243
  self.REAL = jnp.float64
@@ -62,6 +245,7 @@ class JaxRDDLCompiler:
62
245
  else:
63
246
  self.INT = jnp.int32
64
247
  self.REAL = jnp.float32
248
+ jax.config.update('jax_enable_x64', False)
65
249
  self.ONE = jnp.asarray(1, dtype=self.INT)
66
250
  self.JAX_TYPES = {
67
251
  'int': self.INT,
@@ -70,17 +254,16 @@ class JaxRDDLCompiler:
70
254
  }
71
255
 
72
256
  # compile initial values
73
- if self.logger is not None:
74
- self.logger.clear()
75
- initializer = RDDLValueInitializer(rddl, logger=self.logger)
257
+ initializer = RDDLValueInitializer(rddl)
76
258
  self.init_values = initializer.initialize()
77
259
 
78
260
  # compute dependency graph for CPFs and sort them by evaluation order
79
- sorter = RDDLLevelAnalysis(rddl, allow_synchronous_state, logger=self.logger)
261
+ sorter = RDDLLevelAnalysis(
262
+ rddl, allow_synchronous_state=allow_synchronous_state)
80
263
  self.levels = sorter.compute_levels()
81
264
 
82
265
  # trace expressions to cache information to be used later
83
- tracer = RDDLObjectsTracer(rddl, logger=self.logger, cpf_levels=self.levels)
266
+ tracer = RDDLObjectsTracer(rddl, cpf_levels=self.levels)
84
267
  self.traced = tracer.trace()
85
268
 
86
269
  # extract the box constraints on actions
@@ -92,92 +275,42 @@ class JaxRDDLCompiler:
92
275
  constraints = RDDLConstraints(simulator, vectorized=True)
93
276
  self.constraints = constraints
94
277
 
95
- # basic operations
96
- self.NEGATIVE = lambda x, param: jnp.negative(x)
97
- self.ARITHMETIC_OPS = {
98
- '+': lambda x, y, param: jnp.add(x, y),
99
- '-': lambda x, y, param: jnp.subtract(x, y),
100
- '*': lambda x, y, param: jnp.multiply(x, y),
101
- '/': lambda x, y, param: jnp.divide(x, y)
102
- }
103
- self.RELATIONAL_OPS = {
104
- '>=': lambda x, y, param: jnp.greater_equal(x, y),
105
- '<=': lambda x, y, param: jnp.less_equal(x, y),
106
- '<': lambda x, y, param: jnp.less(x, y),
107
- '>': lambda x, y, param: jnp.greater(x, y),
108
- '==': lambda x, y, param: jnp.equal(x, y),
109
- '~=': lambda x, y, param: jnp.not_equal(x, y)
110
- }
111
- self.LOGICAL_NOT = lambda x, param: jnp.logical_not(x)
112
- self.LOGICAL_OPS = {
113
- '^': lambda x, y, param: jnp.logical_and(x, y),
114
- '&': lambda x, y, param: jnp.logical_and(x, y),
115
- '|': lambda x, y, param: jnp.logical_or(x, y),
116
- '~': lambda x, y, param: jnp.logical_xor(x, y),
117
- '=>': lambda x, y, param: jnp.logical_or(jnp.logical_not(x), y),
118
- '<=>': lambda x, y, param: jnp.equal(x, y)
119
- }
120
- self.AGGREGATION_OPS = {
121
- 'sum': lambda x, axis, param: jnp.sum(x, axis=axis),
122
- 'avg': lambda x, axis, param: jnp.mean(x, axis=axis),
123
- 'prod': lambda x, axis, param: jnp.prod(x, axis=axis),
124
- 'minimum': lambda x, axis, param: jnp.min(x, axis=axis),
125
- 'maximum': lambda x, axis, param: jnp.max(x, axis=axis),
126
- 'forall': lambda x, axis, param: jnp.all(x, axis=axis),
127
- 'exists': lambda x, axis, param: jnp.any(x, axis=axis),
128
- 'argmin': lambda x, axis, param: jnp.argmin(x, axis=axis),
129
- 'argmax': lambda x, axis, param: jnp.argmax(x, axis=axis)
130
- }
278
+ # basic operations - these can be overridden in subclasses
279
+ self.compile_non_fluent_exact = compile_non_fluent_exact
280
+ self.NEGATIVE = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_NEGATIVE
281
+ self.ARITHMETIC_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_ARITHMETIC.copy()
282
+ self.RELATIONAL_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_RELATIONAL.copy()
283
+ self.LOGICAL_NOT = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL_NOT
284
+ self.LOGICAL_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL.copy()
285
+ self.AGGREGATION_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_AGGREGATION.copy()
131
286
  self.AGGREGATION_BOOL = {'forall', 'exists'}
132
- self.KNOWN_UNARY = {
133
- 'abs': lambda x, param: jnp.abs(x),
134
- 'sgn': lambda x, param: jnp.sign(x),
135
- 'round': lambda x, param: jnp.round(x),
136
- 'floor': lambda x, param: jnp.floor(x),
137
- 'ceil': lambda x, param: jnp.ceil(x),
138
- 'cos': lambda x, param: jnp.cos(x),
139
- 'sin': lambda x, param: jnp.sin(x),
140
- 'tan': lambda x, param: jnp.tan(x),
141
- 'acos': lambda x, param: jnp.arccos(x),
142
- 'asin': lambda x, param: jnp.arcsin(x),
143
- 'atan': lambda x, param: jnp.arctan(x),
144
- 'cosh': lambda x, param: jnp.cosh(x),
145
- 'sinh': lambda x, param: jnp.sinh(x),
146
- 'tanh': lambda x, param: jnp.tanh(x),
147
- 'exp': lambda x, param: jnp.exp(x),
148
- 'ln': lambda x, param: jnp.log(x),
149
- 'sqrt': lambda x, param: jnp.sqrt(x),
150
- 'lngamma': lambda x, param: scipy.special.gammaln(x),
151
- 'gamma': lambda x, param: jnp.exp(scipy.special.gammaln(x))
152
- }
153
- self.KNOWN_BINARY = {
154
- 'div': lambda x, y, param: jnp.floor_divide(x, y),
155
- 'mod': lambda x, y, param: jnp.mod(x, y),
156
- 'fmod': lambda x, y, param: jnp.mod(x, y),
157
- 'min': lambda x, y, param: jnp.minimum(x, y),
158
- 'max': lambda x, y, param: jnp.maximum(x, y),
159
- 'pow': lambda x, y, param: jnp.power(x, y),
160
- 'log': lambda x, y, param: jnp.log(x) / jnp.log(y),
161
- 'hypot': lambda x, y, param: jnp.hypot(x, y)
162
- }
163
-
287
+ self.KNOWN_UNARY = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_UNARY.copy()
288
+ self.KNOWN_BINARY = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BINARY.copy()
289
+ self.IF_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_IF
290
+ self.SWITCH_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
291
+ self.BERNOULLI_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BERNOULLI
292
+ self.DISCRETE_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
293
+
164
294
  # ===========================================================================
165
295
  # main compilation subroutines
166
296
  # ===========================================================================
167
297
 
168
- def compile(self, log_jax_expr: bool=False) -> None:
298
+ def compile(self, log_jax_expr: bool=False, heading: str='') -> None:
169
299
  '''Compiles the current RDDL into Jax expressions.
170
300
 
171
301
  :param log_jax_expr: whether to pretty-print the compiled Jax functions
172
302
  to the log file
303
+ :param heading: the heading to print before compilation information
173
304
  '''
174
- info = {}
305
+ info = ({}, [])
175
306
  self.invariants = self._compile_constraints(self.rddl.invariants, info)
176
307
  self.preconditions = self._compile_constraints(self.rddl.preconditions, info)
177
308
  self.terminations = self._compile_constraints(self.rddl.terminations, info)
178
309
  self.cpfs = self._compile_cpfs(info)
179
310
  self.reward = self._compile_reward(info)
180
- self.model_params = info
311
+ self.model_params = {key: value
312
+ for (key, (value, *_)) in info[0].items()}
313
+ self.relaxations = info[1]
181
314
 
182
315
  if log_jax_expr and self.logger is not None:
183
316
  printed = self.print_jax()
@@ -189,6 +322,7 @@ class JaxRDDLCompiler:
189
322
  printed_terminals = '\n\n'.join(v for v in printed['terminations'])
190
323
  printed_params = '\n'.join(f'{k}: {v}' for (k, v) in info.items())
191
324
  message = (
325
+ f'[info] {heading}\n'
192
326
  f'[info] compiled JAX CPFs:\n\n'
193
327
  f'{printed_cpfs}\n\n'
194
328
  f'[info] compiled JAX reward:\n\n'
@@ -281,17 +415,21 @@ class JaxRDDLCompiler:
281
415
  return jax_inequalities, jax_equalities
282
416
 
283
417
  def compile_transition(self, check_constraints: bool=False,
284
- constraint_func: bool=False):
418
+ constraint_func: bool=False) -> Callable:
285
419
  '''Compiles the current RDDL model into a JAX transition function that
286
420
  samples the next state.
287
421
 
288
- The signature of the returned function is (key, actions, subs,
289
- model_params), where:
422
+ The arguments of the returned function are:
290
423
  - key is the PRNG key
291
424
  - actions is the dict of action tensors
292
425
  - subs is the dict of current pvar value tensors
293
426
  - model_params is a dict of parameters for the relaxed model.
294
-
427
+
428
+ The returned value of the function is:
429
+ - subs is the returned next epoch fluent values
430
+ - log includes all the auxiliary information about constraints
431
+ satisfied, errors, etc.
432
+
295
433
  constraint_func provides the option to compile nonlinear constraints:
296
434
 
297
435
  1. f(s, a) ?? g(s, a)
@@ -361,6 +499,10 @@ class JaxRDDLCompiler:
361
499
  reward, key, err = reward_fn(subs, model_params, key)
362
500
  errors |= err
363
501
 
502
+ # calculate fluent values
503
+ fluents = {name: values for (name, values) in subs.items()
504
+ if name not in rddl.non_fluents}
505
+
364
506
  # set the next state to the current state
365
507
  for (state, next_state) in rddl.next_state.items():
366
508
  subs[state] = subs[next_state]
@@ -383,8 +525,7 @@ class JaxRDDLCompiler:
383
525
 
384
526
  # prepare the return value
385
527
  log = {
386
- 'pvar': subs,
387
- 'action': actions,
528
+ 'fluents': fluents,
388
529
  'reward': reward,
389
530
  'error': errors,
390
531
  'precondition': precond_check,
@@ -395,7 +536,7 @@ class JaxRDDLCompiler:
395
536
  log['inequalities'] = inequalities
396
537
  log['equalities'] = equalities
397
538
 
398
- return log
539
+ return subs, log
399
540
 
400
541
  return _jax_wrapped_single_step
401
542
 
@@ -403,18 +544,28 @@ class JaxRDDLCompiler:
403
544
  n_steps: int,
404
545
  n_batch: int,
405
546
  check_constraints: bool=False,
406
- constraint_func: bool=False):
547
+ constraint_func: bool=False) -> Callable:
407
548
  '''Compiles the current RDDL model into a JAX transition function that
408
549
  samples trajectories with a fixed horizon from a policy.
409
550
 
410
- The signature of the policy function is (key, params, hyperparams,
411
- step, states), where:
551
+ The arguments of the returned function are:
552
+ - key is the PRNG key (used by a stochastic policy)
553
+ - policy_params is a pytree of trainable policy weights
554
+ - hyperparams is a pytree of (optional) fixed policy hyper-parameters
555
+ - subs is the dictionary of current fluent tensor values
556
+ - model_params is a dict of model hyperparameters.
557
+
558
+ The returned value of the returned function is:
559
+ - log is the dictionary of all trajectory information, including
560
+ constraints that were satisfied, errors, etc.
561
+
562
+ The arguments of the policy function are:
412
563
  - key is the PRNG key (used by a stochastic policy)
413
564
  - params is a pytree of trainable policy weights
414
565
  - hyperparams is a pytree of (optional) fixed policy hyper-parameters
415
566
  - step is the time index of the decision in the current rollout
416
567
  - states is a dict of tensors for the current observation.
417
-
568
+
418
569
  :param policy: a Jax compiled function for the policy as described above
419
570
  decision epoch, state dict, and an RNG key and returns an action dict
420
571
  :param n_steps: the rollout horizon
@@ -428,27 +579,32 @@ class JaxRDDLCompiler:
428
579
  rddl = self.rddl
429
580
  jax_step_fn = self.compile_transition(check_constraints, constraint_func)
430
581
 
582
+ # for POMDP only observ-fluents are assumed visible to the policy
583
+ if rddl.observ_fluents:
584
+ observed_vars = rddl.observ_fluents
585
+ else:
586
+ observed_vars = rddl.state_fluents
587
+
431
588
  # evaluate the step from the policy
432
589
  def _jax_wrapped_single_step_policy(key, policy_params, hyperparams,
433
590
  step, subs, model_params):
434
591
  states = {var: values
435
592
  for (var, values) in subs.items()
436
- if rddl.variable_types[var] == 'state-fluent'}
593
+ if var in observed_vars}
437
594
  actions = policy(key, policy_params, hyperparams, step, states)
438
595
  key, subkey = random.split(key)
439
- log = jax_step_fn(subkey, actions, subs, model_params)
440
- return log
596
+ subs, log = jax_step_fn(subkey, actions, subs, model_params)
597
+ return subs, log
441
598
 
442
599
  # do a batched step update from the policy
443
600
  def _jax_wrapped_batched_step_policy(carry, step):
444
601
  key, policy_params, hyperparams, subs, model_params = carry
445
602
  key, *subkeys = random.split(key, num=1 + n_batch)
446
603
  keys = jnp.asarray(subkeys)
447
- log = jax.vmap(
604
+ subs, log = jax.vmap(
448
605
  _jax_wrapped_single_step_policy,
449
606
  in_axes=(0, None, None, None, 0, None)
450
607
  )(keys, policy_params, hyperparams, step, subs, model_params)
451
- subs = log['pvar']
452
608
  carry = (key, policy_params, hyperparams, subs, model_params)
453
609
  return carry, log
454
610
 
@@ -467,7 +623,7 @@ class JaxRDDLCompiler:
467
623
  # error checks
468
624
  # ===========================================================================
469
625
 
470
- def print_jax(self) -> Dict[str, object]:
626
+ def print_jax(self) -> Dict[str, Any]:
471
627
  '''Returns a dictionary containing the string representations of all
472
628
  Jax compiled expressions from the RDDL file.
473
629
  '''
@@ -564,7 +720,7 @@ class JaxRDDLCompiler:
564
720
  }
565
721
 
566
722
  @staticmethod
567
- def get_error_codes(error):
723
+ def get_error_codes(error: int) -> List[int]:
568
724
  Given a compacted integer error flag from the execution of Jax,
569
725
  decomposes it into individual error codes.
570
726
  '''
@@ -573,7 +729,7 @@ class JaxRDDLCompiler:
573
729
  return errors
574
730
 
575
731
  @staticmethod
576
- def get_error_messages(error):
732
+ def get_error_messages(error: int) -> List[str]:
577
733
  Given a compacted integer error flag from the execution of Jax,
578
734
  decomposes it into error strings.
579
735
  '''
@@ -586,28 +742,40 @@ class JaxRDDLCompiler:
586
742
  # ===========================================================================
587
743
 
588
744
  def _unwrap(self, op, expr_id, info):
589
- sep = JaxRDDLCompiler.MODEL_PARAM_TAG_SEPARATOR
590
745
  jax_op, name = op, None
746
+ model_params, relaxed_list = info
591
747
  if isinstance(op, tuple):
592
748
  jax_op, param = op
593
749
  if param is not None:
594
750
  tags, values = param
751
+ sep = JaxRDDLCompiler.MODEL_PARAM_TAG_SEPARATOR
595
752
  if isinstance(tags, tuple):
596
753
  name = sep.join(tags)
597
754
  else:
598
755
  name = str(tags)
599
756
  name = f'{name}{sep}{expr_id}'
600
- if name in info:
601
- raise Exception(f'Model parameter {name} is already defined.')
602
- info[name] = values
757
+ if name in model_params:
758
+ raise RuntimeError(
759
+ f'Internal error: model parameter {name} is already defined.')
760
+ model_params[name] = (values, tags, expr_id, jax_op.__name__)
761
+ relaxed_list.append((param, expr_id, jax_op.__name__))
603
762
  return jax_op, name
604
763
 
605
- def get_ids_of_parameterized_expressions(self) -> List[int]:
606
- '''Returns a list of expression IDs that have tuning parameters.'''
607
- sep = JaxRDDLCompiler.MODEL_PARAM_TAG_SEPARATOR
608
- ids = [int(key.split(sep)[-1]) for key in self.model_params]
609
- return ids
610
-
764
+ def summarize_model_relaxations(self) -> str:
765
+ '''Returns a string of information about model relaxations in the
766
+ compiled model.'''
767
+ occurence_by_type = {}
768
+ for (_, expr_id, jax_op) in self.relaxations:
769
+ etype = self.traced.lookup(expr_id).etype
770
+ source = f'{etype[1]} ({etype[0]})'
771
+ sub = f'{source:<30} --> {jax_op}'
772
+ occurence_by_type[sub] = occurence_by_type.get(sub, 0) + 1
773
+ col = "{:<80} {:<10}\n"
774
+ table = col.format('Substitution', 'Count')
775
+ for (sub, occurs) in occurence_by_type.items():
776
+ table += col.format(sub, occurs)
777
+ return table
778
+
611
779
  # ===========================================================================
612
780
  # expression compilation
613
781
  # ===========================================================================
@@ -640,7 +808,8 @@ class JaxRDDLCompiler:
640
808
  raise RDDLNotImplementedError(
641
809
  f'Internal error: expression type {expr} is not supported.\n' +
642
810
  print_stack_trace(expr))
643
-
811
+
812
+ # force type cast of tensor as required by caller
644
813
  if dtype is not None:
645
814
  jax_expr = self._jax_cast(jax_expr, dtype)
646
815
 
@@ -660,6 +829,17 @@ class JaxRDDLCompiler:
660
829
 
661
830
  return _jax_wrapped_cast
662
831
 
832
+ def _fix_dtype(self, value):
833
+ dtype = jnp.atleast_1d(value).dtype
834
+ if jnp.issubdtype(dtype, jnp.integer):
835
+ return self.INT
836
+ elif jnp.issubdtype(dtype, jnp.floating):
837
+ return self.REAL
838
+ elif jnp.issubdtype(dtype, jnp.bool_) or jnp.issubdtype(dtype, bool):
839
+ return bool
840
+ else:
841
+ raise TypeError(f'Invalid type {dtype} of {value}.')
842
+
663
843
  # ===========================================================================
664
844
  # leaves
665
845
  # ===========================================================================
@@ -669,7 +849,7 @@ class JaxRDDLCompiler:
669
849
  cached_value = self.traced.cached_sim_info(expr)
670
850
 
671
851
  def _jax_wrapped_constant(x, params, key):
672
- sample = jnp.asarray(cached_value)
852
+ sample = jnp.asarray(cached_value, dtype=self._fix_dtype(cached_value))
673
853
  return sample, key, NORMAL
674
854
 
675
855
  return _jax_wrapped_constant
@@ -693,7 +873,7 @@ class JaxRDDLCompiler:
693
873
  cached_value = cached_info
694
874
 
695
875
  def _jax_wrapped_object(x, params, key):
696
- sample = jnp.asarray(cached_value)
876
+ sample = jnp.asarray(cached_value, dtype=self._fix_dtype(cached_value))
697
877
  return sample, key, NORMAL
698
878
 
699
879
  return _jax_wrapped_object
@@ -702,7 +882,8 @@ class JaxRDDLCompiler:
702
882
  elif cached_info is None:
703
883
 
704
884
  def _jax_wrapped_pvar_scalar(x, params, key):
705
- sample = jnp.asarray(x[var])
885
+ value = x[var]
886
+ sample = jnp.asarray(value, dtype=self._fix_dtype(value))
706
887
  return sample, key, NORMAL
707
888
 
708
889
  return _jax_wrapped_pvar_scalar
@@ -721,7 +902,8 @@ class JaxRDDLCompiler:
721
902
 
722
903
  def _jax_wrapped_pvar_tensor_nested(x, params, key):
723
904
  error = NORMAL
724
- sample = jnp.asarray(x[var])
905
+ value = x[var]
906
+ sample = jnp.asarray(value, dtype=self._fix_dtype(value))
725
907
  new_slices = [None] * len(jax_nested_expr)
726
908
  for (i, jax_expr) in enumerate(jax_nested_expr):
727
909
  new_slices[i], key, err = jax_expr(x, params, key)
@@ -736,7 +918,8 @@ class JaxRDDLCompiler:
736
918
  else:
737
919
 
738
920
  def _jax_wrapped_pvar_tensor_non_nested(x, params, key):
739
- sample = jnp.asarray(x[var])
921
+ value = x[var]
922
+ sample = jnp.asarray(value, dtype=self._fix_dtype(value))
740
923
  if slices:
741
924
  sample = sample[slices]
742
925
  if axis:
@@ -795,16 +978,23 @@ class JaxRDDLCompiler:
795
978
 
796
979
  def _jax_arithmetic(self, expr, info):
797
980
  _, op = expr.etype
798
- valid_ops = self.ARITHMETIC_OPS
981
+
982
+ # if expression is non-fluent, always use the exact operation
983
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
984
+ valid_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_ARITHMETIC
985
+ negative_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_NEGATIVE
986
+ else:
987
+ valid_ops = self.ARITHMETIC_OPS
988
+ negative_op = self.NEGATIVE
799
989
  JaxRDDLCompiler._check_valid_op(expr, valid_ops)
800
-
990
+
991
+ # recursively compile arguments
801
992
  args = expr.args
802
993
  n = len(args)
803
-
804
994
  if n == 1 and op == '-':
805
995
  arg, = args
806
996
  jax_expr = self._jax(arg, info)
807
- jax_op, jax_param = self._unwrap(self.NEGATIVE, expr.id, info)
997
+ jax_op, jax_param = self._unwrap(negative_op, expr.id, info)
808
998
  return self._jax_unary(jax_expr, jax_op, jax_param, at_least_int=True)
809
999
 
810
1000
  elif n == 2:
@@ -819,29 +1009,42 @@ class JaxRDDLCompiler:
819
1009
 
820
1010
  def _jax_relational(self, expr, info):
821
1011
  _, op = expr.etype
822
- valid_ops = self.RELATIONAL_OPS
1012
+
1013
+ # if expression is non-fluent, always use the exact operation
1014
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
1015
+ valid_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_RELATIONAL
1016
+ else:
1017
+ valid_ops = self.RELATIONAL_OPS
823
1018
  JaxRDDLCompiler._check_valid_op(expr, valid_ops)
824
- JaxRDDLCompiler._check_num_args(expr, 2)
1019
+ jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
825
1020
 
1021
+ # recursively compile arguments
1022
+ JaxRDDLCompiler._check_num_args(expr, 2)
826
1023
  lhs, rhs = expr.args
827
1024
  jax_lhs = self._jax(lhs, info)
828
1025
  jax_rhs = self._jax(rhs, info)
829
- jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
830
1026
  return self._jax_binary(
831
1027
  jax_lhs, jax_rhs, jax_op, jax_param, at_least_int=True)
832
1028
 
833
1029
  def _jax_logical(self, expr, info):
834
1030
  _, op = expr.etype
835
- valid_ops = self.LOGICAL_OPS
1031
+
1032
+ # if expression is non-fluent, always use the exact operation
1033
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
1034
+ valid_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL
1035
+ logical_not_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL_NOT
1036
+ else:
1037
+ valid_ops = self.LOGICAL_OPS
1038
+ logical_not_op = self.LOGICAL_NOT
836
1039
  JaxRDDLCompiler._check_valid_op(expr, valid_ops)
837
1040
 
1041
+ # recursively compile arguments
838
1042
  args = expr.args
839
- n = len(args)
840
-
1043
+ n = len(args)
841
1044
  if n == 1 and op == '~':
842
1045
  arg, = args
843
1046
  jax_expr = self._jax(arg, info)
844
- jax_op, jax_param = self._unwrap(self.LOGICAL_NOT, expr.id, info)
1047
+ jax_op, jax_param = self._unwrap(logical_not_op, expr.id, info)
845
1048
  return self._jax_unary(jax_expr, jax_op, jax_param, check_dtype=bool)
846
1049
 
847
1050
  elif n == 2:
@@ -856,17 +1059,21 @@ class JaxRDDLCompiler:
856
1059
 
857
1060
  def _jax_aggregation(self, expr, info):
858
1061
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
859
-
860
1062
  _, op = expr.etype
861
- valid_ops = self.AGGREGATION_OPS
1063
+
1064
+ # if expression is non-fluent, always use the exact operation
1065
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
1066
+ valid_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_AGGREGATION
1067
+ else:
1068
+ valid_ops = self.AGGREGATION_OPS
862
1069
  JaxRDDLCompiler._check_valid_op(expr, valid_ops)
863
- is_floating = op not in self.AGGREGATION_BOOL
1070
+ jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
864
1071
 
1072
+ # recursively compile arguments
1073
+ is_floating = op not in self.AGGREGATION_BOOL
865
1074
  * _, arg = expr.args
866
- _, axes = self.traced.cached_sim_info(expr)
867
-
1075
+ _, axes = self.traced.cached_sim_info(expr)
868
1076
  jax_expr = self._jax(arg, info)
869
- jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
870
1077
 
871
1078
  def _jax_wrapped_aggregation(x, params, key):
872
1079
  sample, key, err = jax_expr(x, params, key)
@@ -884,21 +1091,28 @@ class JaxRDDLCompiler:
884
1091
  def _jax_functional(self, expr, info):
885
1092
  _, op = expr.etype
886
1093
 
887
- # unary function
888
- if op in self.KNOWN_UNARY:
1094
+ # if expression is non-fluent, always use the exact operation
1095
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
1096
+ unary_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_UNARY
1097
+ binary_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BINARY
1098
+ else:
1099
+ unary_ops = self.KNOWN_UNARY
1100
+ binary_ops = self.KNOWN_BINARY
1101
+
1102
+ # recursively compile arguments
1103
+ if op in unary_ops:
889
1104
  JaxRDDLCompiler._check_num_args(expr, 1)
890
1105
  arg, = expr.args
891
1106
  jax_expr = self._jax(arg, info)
892
- jax_op, jax_param = self._unwrap(self.KNOWN_UNARY[op], expr.id, info)
1107
+ jax_op, jax_param = self._unwrap(unary_ops[op], expr.id, info)
893
1108
  return self._jax_unary(jax_expr, jax_op, jax_param, at_least_int=True)
894
1109
 
895
- # binary function
896
- elif op in self.KNOWN_BINARY:
1110
+ elif op in binary_ops:
897
1111
  JaxRDDLCompiler._check_num_args(expr, 2)
898
1112
  lhs, rhs = expr.args
899
1113
  jax_lhs = self._jax(lhs, info)
900
1114
  jax_rhs = self._jax(rhs, info)
901
- jax_op, jax_param = self._unwrap(self.KNOWN_BINARY[op], expr.id, info)
1115
+ jax_op, jax_param = self._unwrap(binary_ops[op], expr.id, info)
902
1116
  return self._jax_binary(
903
1117
  jax_lhs, jax_rhs, jax_op, jax_param, at_least_int=True)
904
1118
 
@@ -921,19 +1135,19 @@ class JaxRDDLCompiler:
921
1135
  f'Control operator {op} is not supported.\n' +
922
1136
  print_stack_trace(expr))
923
1137
 
924
- def _jax_if_helper(self):
925
-
926
- def _jax_wrapped_if_calc_exact(c, a, b, param):
927
- return jnp.where(c, a, b)
928
-
929
- return _jax_wrapped_if_calc_exact
930
-
931
1138
  def _jax_if(self, expr, info):
932
1139
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
933
1140
  JaxRDDLCompiler._check_num_args(expr, 3)
934
- jax_if, jax_param = self._unwrap(self._jax_if_helper(), expr.id, info)
1141
+ pred, if_true, if_false = expr.args
1142
+
1143
+ # if predicate is non-fluent, always use the exact operation
1144
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(pred):
1145
+ if_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_IF
1146
+ else:
1147
+ if_op = self.IF_HELPER
1148
+ jax_if, jax_param = self._unwrap(if_op, expr.id, info)
935
1149
 
936
- pred, if_true, if_false = expr.args
1150
+ # recursively compile arguments
937
1151
  jax_pred = self._jax(pred, info)
938
1152
  jax_true = self._jax(if_true, info)
939
1153
  jax_false = self._jax(if_false, info)
@@ -951,23 +1165,20 @@ class JaxRDDLCompiler:
951
1165
 
952
1166
  return _jax_wrapped_if_then_else
953
1167
 
954
- def _jax_switch_helper(self):
955
-
956
- def _jax_wrapped_switch_calc_exact(pred, cases, param):
957
- pred = pred[jnp.newaxis, ...]
958
- sample = jnp.take_along_axis(cases, pred, axis=0)
959
- assert sample.shape[0] == 1
960
- return sample[0, ...]
961
-
962
- return _jax_wrapped_switch_calc_exact
963
-
964
1168
  def _jax_switch(self, expr, info):
965
- pred, *_ = expr.args
1169
+
1170
+ # if expression is non-fluent, always use the exact operation
1171
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
1172
+ switch_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
1173
+ else:
1174
+ switch_op = self.SWITCH_HELPER
1175
+ jax_switch, jax_param = self._unwrap(switch_op, expr.id, info)
1176
+
1177
+ # recursively compile predicate
1178
+ pred, *_ = expr.args
966
1179
  jax_pred = self._jax(pred, info)
967
- jax_switch, jax_param = self._unwrap(
968
- self._jax_switch_helper(), expr.id, info)
969
1180
 
970
- # wrap cases as JAX expressions
1181
+ # recursively compile cases
971
1182
  cases, default = self.traced.cached_sim_info(expr)
972
1183
  jax_default = None if default is None else self._jax(default, info)
973
1184
  jax_cases = [(jax_default if _case is None else self._jax(_case, info))
@@ -983,7 +1194,8 @@ class JaxRDDLCompiler:
983
1194
  for (i, jax_case) in enumerate(jax_cases):
984
1195
  sample_cases[i], key, err_case = jax_case(x, params, key)
985
1196
  err |= err_case
986
- sample_cases = jnp.asarray(sample_cases)
1197
+ sample_cases = jnp.asarray(
1198
+ sample_cases, dtype=self._fix_dtype(sample_cases))
987
1199
 
988
1200
  # predicate (enum) is an integer - use it to extract from case array
989
1201
  param = params.get(jax_param, None)
@@ -1179,30 +1391,28 @@ class JaxRDDLCompiler:
1179
1391
  scale, key, err2 = jax_scale(x, params, key)
1180
1392
  key, subkey = random.split(key)
1181
1393
  U = random.uniform(key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
1182
- sample = scale * jnp.power(-jnp.log1p(-U), 1.0 / shape)
1394
+ sample = scale * jnp.power(-jnp.log(U), 1.0 / shape)
1183
1395
  out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
1184
1396
  err = err1 | err2 | (out_of_bounds * ERR)
1185
1397
  return sample, key, err
1186
1398
 
1187
1399
  return _jax_wrapped_distribution_weibull
1188
1400
 
1189
- def _jax_bernoulli_helper(self):
1190
-
1191
- def _jax_wrapped_calc_bernoulli_exact(key, prob, param):
1192
- return random.bernoulli(key, prob)
1193
-
1194
- return _jax_wrapped_calc_bernoulli_exact
1195
-
1196
1401
  def _jax_bernoulli(self, expr, info):
1197
1402
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_BERNOULLI']
1198
1403
  JaxRDDLCompiler._check_num_args(expr, 1)
1199
- jax_bern, jax_param = self._unwrap(
1200
- self._jax_bernoulli_helper(), expr.id, info)
1201
-
1202
1404
  arg_prob, = expr.args
1405
+
1406
+ # if probability is non-fluent, always use the exact operation
1407
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_prob):
1408
+ bern_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BERNOULLI
1409
+ else:
1410
+ bern_op = self.BERNOULLI_HELPER
1411
+ jax_bern, jax_param = self._unwrap(bern_op, expr.id, info)
1412
+
1413
+ # recursively compile arguments
1203
1414
  jax_prob = self._jax(arg_prob, info)
1204
1415
 
1205
- # uses the implicit JAX subroutine
1206
1416
  def _jax_wrapped_distribution_bernoulli(x, params, key):
1207
1417
  prob, key, err = jax_prob(x, params, key)
1208
1418
  key, subkey = random.split(key)
@@ -1266,8 +1476,8 @@ class JaxRDDLCompiler:
1266
1476
  def _jax_wrapped_distribution_binomial(x, params, key):
1267
1477
  trials, key, err2 = jax_trials(x, params, key)
1268
1478
  prob, key, err1 = jax_prob(x, params, key)
1269
- trials = jnp.asarray(trials, self.REAL)
1270
- prob = jnp.asarray(prob, self.REAL)
1479
+ trials = jnp.asarray(trials, dtype=self.REAL)
1480
+ prob = jnp.asarray(prob, dtype=self.REAL)
1271
1481
  key, subkey = random.split(key)
1272
1482
  dist = tfp.distributions.Binomial(total_count=trials, probs=prob)
1273
1483
  sample = dist.sample(seed=subkey).astype(self.INT)
@@ -1290,11 +1500,10 @@ class JaxRDDLCompiler:
1290
1500
  def _jax_wrapped_distribution_negative_binomial(x, params, key):
1291
1501
  trials, key, err2 = jax_trials(x, params, key)
1292
1502
  prob, key, err1 = jax_prob(x, params, key)
1293
- trials = jnp.asarray(trials, self.REAL)
1294
- prob = jnp.asarray(prob, self.REAL)
1503
+ trials = jnp.asarray(trials, dtype=self.REAL)
1504
+ prob = jnp.asarray(prob, dtype=self.REAL)
1295
1505
  key, subkey = random.split(key)
1296
- dist = tfp.distributions.NegativeBinomial(
1297
- total_count=trials, probs=prob)
1506
+ dist = tfp.distributions.NegativeBinomial(total_count=trials, probs=prob)
1298
1507
  sample = dist.sample(seed=subkey).astype(self.INT)
1299
1508
  out_of_bounds = jnp.logical_not(jnp.all(
1300
1509
  (prob >= 0) & (prob <= 1) & (trials > 0)))
@@ -1316,7 +1525,7 @@ class JaxRDDLCompiler:
1316
1525
  shape, key, err1 = jax_shape(x, params, key)
1317
1526
  rate, key, err2 = jax_rate(x, params, key)
1318
1527
  key, subkey = random.split(key)
1319
- sample = random.beta(key=subkey, a=shape, b=rate)
1528
+ sample = random.beta(key=subkey, a=shape, b=rate, dtype=self.REAL)
1320
1529
  out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (rate > 0)))
1321
1530
  err = err1 | err2 | (out_of_bounds * ERR)
1322
1531
  return sample, key, err
@@ -1325,23 +1534,35 @@ class JaxRDDLCompiler:
1325
1534
 
1326
1535
  def _jax_geometric(self, expr, info):
1327
1536
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_GEOMETRIC']
1328
- JaxRDDLCompiler._check_num_args(expr, 1)
1329
-
1537
+ JaxRDDLCompiler._check_num_args(expr, 1)
1330
1538
  arg_prob, = expr.args
1331
1539
  jax_prob = self._jax(arg_prob, info)
1332
- floor_op, jax_param = self._unwrap(
1333
- self.KNOWN_UNARY['floor'], expr.id, info)
1334
1540
 
1335
- # reparameterization trick Geom(p) = floor(ln(U(0, 1)) / ln(1 - p)) + 1
1336
- def _jax_wrapped_distribution_geometric(x, params, key):
1337
- prob, key, err = jax_prob(x, params, key)
1338
- key, subkey = random.split(key)
1339
- U = random.uniform(key=subkey, shape=jnp.shape(prob), dtype=self.REAL)
1340
- param = params.get(jax_param, None)
1341
- sample = floor_op(jnp.log1p(-U) / jnp.log1p(-prob), param) + 1
1342
- out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
1343
- err |= (out_of_bounds * ERR)
1344
- return sample, key, err
1541
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_prob):
1542
+
1543
+ # prob is non-fluent: do not reparameterize
1544
+ def _jax_wrapped_distribution_geometric(x, params, key):
1545
+ prob, key, err = jax_prob(x, params, key)
1546
+ key, subkey = random.split(key)
1547
+ sample = random.geometric(key=subkey, p=prob, dtype=self.INT)
1548
+ out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
1549
+ err |= (out_of_bounds * ERR)
1550
+ return sample, key, err
1551
+
1552
+ else:
1553
+ floor_op, jax_param = self._unwrap(
1554
+ self.KNOWN_UNARY['floor'], expr.id, info)
1555
+
1556
+ # reparameterization trick Geom(p) = floor(ln(U(0, 1)) / ln(1 - p)) + 1
1557
+ def _jax_wrapped_distribution_geometric(x, params, key):
1558
+ prob, key, err = jax_prob(x, params, key)
1559
+ key, subkey = random.split(key)
1560
+ U = random.uniform(key=subkey, shape=jnp.shape(prob), dtype=self.REAL)
1561
+ param = params.get(jax_param, None)
1562
+ sample = floor_op(jnp.log(U) / jnp.log(1.0 - prob), param) + 1
1563
+ out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
1564
+ err |= (out_of_bounds * ERR)
1565
+ return sample, key, err
1345
1566
 
1346
1567
  return _jax_wrapped_distribution_geometric
1347
1568
 
@@ -1359,7 +1580,7 @@ class JaxRDDLCompiler:
1359
1580
  shape, key, err1 = jax_shape(x, params, key)
1360
1581
  scale, key, err2 = jax_scale(x, params, key)
1361
1582
  key, subkey = random.split(key)
1362
- sample = scale * random.pareto(key=subkey, b=shape)
1583
+ sample = scale * random.pareto(key=subkey, b=shape, dtype=self.REAL)
1363
1584
  out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
1364
1585
  err = err1 | err2 | (out_of_bounds * ERR)
1365
1586
  return sample, key, err
@@ -1377,7 +1598,8 @@ class JaxRDDLCompiler:
1377
1598
  def _jax_wrapped_distribution_t(x, params, key):
1378
1599
  df, key, err = jax_df(x, params, key)
1379
1600
  key, subkey = random.split(key)
1380
- sample = random.t(key=subkey, df=df, shape=jnp.shape(df))
1601
+ sample = random.t(
1602
+ key=subkey, df=df, shape=jnp.shape(df), dtype=self.REAL)
1381
1603
  out_of_bounds = jnp.logical_not(jnp.all(df > 0))
1382
1604
  err |= (out_of_bounds * ERR)
1383
1605
  return sample, key, err
@@ -1464,7 +1686,7 @@ class JaxRDDLCompiler:
1464
1686
  scale, key, err2 = jax_scale(x, params, key)
1465
1687
  key, subkey = random.split(key)
1466
1688
  U = random.uniform(key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
1467
- sample = jnp.log(1.0 - jnp.log1p(-U) / shape) / scale
1689
+ sample = jnp.log(1.0 - jnp.log(U) / shape) / scale
1468
1690
  out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
1469
1691
  err = err1 | err2 | (out_of_bounds * ERR)
1470
1692
  return sample, key, err
@@ -1516,25 +1738,21 @@ class JaxRDDLCompiler:
1516
1738
  # random variables with enum support
1517
1739
  # ===========================================================================
1518
1740
 
1519
- def _jax_discrete_helper(self):
1520
-
1521
- def _jax_wrapped_discrete_calc_exact(key, prob, param):
1522
- logits = jnp.log(prob)
1523
- sample = random.categorical(key=key, logits=logits, axis=-1)
1524
- out_of_bounds = jnp.logical_not(jnp.logical_and(
1525
- jnp.all(prob >= 0),
1526
- jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
1527
- return sample, out_of_bounds
1528
-
1529
- return _jax_wrapped_discrete_calc_exact
1530
-
1531
1741
  def _jax_discrete(self, expr, info, unnorm):
1532
1742
  NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
1533
1743
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_DISCRETE']
1534
- jax_discrete, jax_param = self._unwrap(
1535
- self._jax_discrete_helper(), expr.id, info)
1536
-
1537
1744
  ordered_args = self.traced.cached_sim_info(expr)
1745
+
1746
+ # if all probabilities are non-fluent, then always sample exactly
1747
+ has_fluent_arg = any(self.traced.cached_is_fluent(arg)
1748
+ for arg in ordered_args)
1749
+ if self.compile_non_fluent_exact and not has_fluent_arg:
1750
+ discrete_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
1751
+ else:
1752
+ discrete_op = self.DISCRETE_HELPER
1753
+ jax_discrete, jax_param = self._unwrap(discrete_op, expr.id, info)
1754
+
1755
+ # compile probability expressions
1538
1756
  jax_probs = [self._jax(arg, info) for arg in ordered_args]
1539
1757
 
1540
1758
  def _jax_wrapped_distribution_discrete(x, params, key):
@@ -1561,12 +1779,18 @@ class JaxRDDLCompiler:
1561
1779
 
1562
1780
  def _jax_discrete_pvar(self, expr, info, unnorm):
1563
1781
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_DISCRETE']
1564
- JaxRDDLCompiler._check_num_args(expr, 1)
1565
- jax_discrete, jax_param = self._unwrap(
1566
- self._jax_discrete_helper(), expr.id, info)
1567
-
1782
+ JaxRDDLCompiler._check_num_args(expr, 2)
1568
1783
  _, args = expr.args
1569
1784
  arg, = args
1785
+
1786
+ # if probabilities are non-fluent, then always sample exactly
1787
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg):
1788
+ discrete_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
1789
+ else:
1790
+ discrete_op = self.DISCRETE_HELPER
1791
+ jax_discrete, jax_param = self._unwrap(discrete_op, expr.id, info)
1792
+
1793
+ # compile probability function
1570
1794
  jax_probs = self._jax(arg, info)
1571
1795
 
1572
1796
  def _jax_wrapped_distribution_discrete_pvar(x, params, key):
@@ -1687,7 +1911,7 @@ class JaxRDDLCompiler:
1687
1911
  out_of_bounds = jnp.logical_not(jnp.all(alpha > 0))
1688
1912
  error |= (out_of_bounds * ERR)
1689
1913
  key, subkey = random.split(key)
1690
- Gamma = random.gamma(key=subkey, a=alpha)
1914
+ Gamma = random.gamma(key=subkey, a=alpha, dtype=self.REAL)
1691
1915
  sample = Gamma / jnp.sum(Gamma, axis=-1, keepdims=True)
1692
1916
  sample = jnp.moveaxis(sample, source=-1, destination=index)
1693
1917
  return sample, key, error
@@ -1706,8 +1930,8 @@ class JaxRDDLCompiler:
1706
1930
  def _jax_wrapped_distribution_multinomial(x, params, key):
1707
1931
  trials, key, err1 = jax_trials(x, params, key)
1708
1932
  prob, key, err2 = jax_prob(x, params, key)
1709
- trials = jnp.asarray(trials, self.REAL)
1710
- prob = jnp.asarray(prob, self.REAL)
1933
+ trials = jnp.asarray(trials, dtype=self.REAL)
1934
+ prob = jnp.asarray(prob, dtype=self.REAL)
1711
1935
  key, subkey = random.split(key)
1712
1936
  dist = tfp.distributions.Multinomial(total_count=trials, probs=prob)
1713
1937
  sample = dist.sample(seed=subkey).astype(self.INT)