pyRDDLGym-jax 0.5__py3-none-any.whl → 1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43):
  1. pyRDDLGym_jax/__init__.py +1 -1
  2. pyRDDLGym_jax/core/compiler.py +463 -592
  3. pyRDDLGym_jax/core/logic.py +784 -544
  4. pyRDDLGym_jax/core/planner.py +329 -463
  5. pyRDDLGym_jax/core/simulator.py +7 -5
  6. pyRDDLGym_jax/core/tuning.py +379 -568
  7. pyRDDLGym_jax/core/visualization.py +1463 -0
  8. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
  9. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +4 -5
  10. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
  11. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
  12. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
  13. pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
  14. pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
  15. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
  16. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +3 -3
  17. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
  18. pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
  19. pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
  20. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
  21. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +3 -3
  22. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
  23. pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
  24. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
  25. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +3 -3
  26. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
  27. pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
  28. pyRDDLGym_jax/examples/configs/default_replan.cfg +3 -3
  29. pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
  30. pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
  31. pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
  32. pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
  33. pyRDDLGym_jax/examples/run_plan.py +4 -1
  34. pyRDDLGym_jax/examples/run_tune.py +40 -27
  35. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/METADATA +161 -104
  36. pyRDDLGym_jax-1.0.dist-info/RECORD +45 -0
  37. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/WHEEL +1 -1
  38. pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
  39. pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
  40. pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
  41. pyRDDLGym_jax-0.5.dist-info/RECORD +0 -44
  42. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/LICENSE +0 -0
  43. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/top_level.txt +0 -0
@@ -21,6 +21,8 @@ from pyRDDLGym.core.debug.exception import (
21
21
  from pyRDDLGym.core.debug.logger import Logger
22
22
  from pyRDDLGym.core.simulator import RDDLSimulatorPrecompiled
23
23
 
24
+ from pyRDDLGym_jax.core.logic import ExactLogic
25
+
24
26
  # more robust approach - if user does not have this or broken try to continue
25
27
  try:
26
28
  from tensorflow_probability.substrates import jax as tfp
@@ -32,109 +34,6 @@ except Exception:
32
34
  tfp = None
33
35
 
34
36
 
35
- # ===========================================================================
36
- # EXACT RDDL TO JAX COMPILATION RULES
37
- # ===========================================================================
38
-
39
- def _function_unary_exact_named(op, name):
40
-
41
- def _jax_wrapped_unary_fn_exact(x, param):
42
- return op(x)
43
-
44
- return _jax_wrapped_unary_fn_exact
45
-
46
-
47
- def _function_unary_exact_named_gamma():
48
-
49
- def _jax_wrapped_unary_gamma_exact(x, param):
50
- return jnp.exp(scipy.special.gammaln(x))
51
-
52
- return _jax_wrapped_unary_gamma_exact
53
-
54
-
55
- def _function_binary_exact_named(op, name):
56
-
57
- def _jax_wrapped_binary_fn_exact(x, y, param):
58
- return op(x, y)
59
-
60
- return _jax_wrapped_binary_fn_exact
61
-
62
-
63
- def _function_binary_exact_named_implies():
64
-
65
- def _jax_wrapped_binary_implies_exact(x, y, param):
66
- return jnp.logical_or(jnp.logical_not(x), y)
67
-
68
- return _jax_wrapped_binary_implies_exact
69
-
70
-
71
- def _function_binary_exact_named_log():
72
-
73
- def _jax_wrapped_binary_log_exact(x, y, param):
74
- return jnp.log(x) / jnp.log(y)
75
-
76
- return _jax_wrapped_binary_log_exact
77
-
78
-
79
- def _function_aggregation_exact_named(op, name):
80
-
81
- def _jax_wrapped_aggregation_fn_exact(x, axis, param):
82
- return op(x, axis=axis)
83
-
84
- return _jax_wrapped_aggregation_fn_exact
85
-
86
-
87
- def _function_if_exact_named():
88
-
89
- def _jax_wrapped_if_exact(c, a, b, param):
90
- return jnp.where(c > 0.5, a, b)
91
-
92
- return _jax_wrapped_if_exact
93
-
94
-
95
- def _function_switch_exact_named():
96
-
97
- def _jax_wrapped_switch_exact(pred, cases, param):
98
- pred = pred[jnp.newaxis, ...]
99
- sample = jnp.take_along_axis(cases, pred, axis=0)
100
- assert sample.shape[0] == 1
101
- return sample[0, ...]
102
-
103
- return _jax_wrapped_switch_exact
104
-
105
-
106
- def _function_bernoulli_exact_named():
107
-
108
- def _jax_wrapped_bernoulli_exact(key, prob, param):
109
- return random.bernoulli(key, prob)
110
-
111
- return _jax_wrapped_bernoulli_exact
112
-
113
-
114
- def _function_discrete_exact_named():
115
-
116
- def _jax_wrapped_discrete_exact(key, prob, param):
117
- return random.categorical(key=key, logits=jnp.log(prob), axis=-1)
118
-
119
- return _jax_wrapped_discrete_exact
120
-
121
-
122
- def _function_poisson_exact_named():
123
-
124
- def _jax_wrapped_poisson_exact(key, rate, param):
125
- return random.poisson(key=key, lam=rate, dtype=jnp.int64)
126
-
127
- return _jax_wrapped_poisson_exact
128
-
129
-
130
- def _function_geometric_exact_named():
131
-
132
- def _jax_wrapped_geometric_exact(key, prob, param):
133
- return random.geometric(key=key, p=prob, dtype=jnp.int64)
134
-
135
- return _jax_wrapped_geometric_exact
136
-
137
-
138
37
  class JaxRDDLCompiler:
139
38
  '''Compiles a RDDL AST representation into an equivalent JAX representation.
140
39
  All operations are identical to their numpy equivalents.
@@ -145,88 +44,96 @@ class JaxRDDLCompiler:
145
44
  # ===========================================================================
146
45
  # EXACT RDDL TO JAX COMPILATION RULES BY DEFAULT
147
46
  # ===========================================================================
148
-
149
- EXACT_RDDL_TO_JAX_NEGATIVE = _function_unary_exact_named(jnp.negative, 'negative')
150
47
 
151
- EXACT_RDDL_TO_JAX_ARITHMETIC = {
152
- '+': _function_binary_exact_named(jnp.add, 'add'),
153
- '-': _function_binary_exact_named(jnp.subtract, 'subtract'),
154
- '*': _function_binary_exact_named(jnp.multiply, 'multiply'),
155
- '/': _function_binary_exact_named(jnp.divide, 'divide')
156
- }
48
+ @staticmethod
49
+ def wrap_logic(func):
50
+ def exact_func(id, init_params):
51
+ return func
52
+ return exact_func
157
53
 
54
+ EXACT_RDDL_TO_JAX_NEGATIVE = wrap_logic(ExactLogic.exact_unary_function(jnp.negative))
55
+ EXACT_RDDL_TO_JAX_ARITHMETIC = {
56
+ '+': wrap_logic(ExactLogic.exact_binary_function(jnp.add)),
57
+ '-': wrap_logic(ExactLogic.exact_binary_function(jnp.subtract)),
58
+ '*': wrap_logic(ExactLogic.exact_binary_function(jnp.multiply)),
59
+ '/': wrap_logic(ExactLogic.exact_binary_function(jnp.divide))
60
+ }
61
+
158
62
  EXACT_RDDL_TO_JAX_RELATIONAL = {
159
- '>=': _function_binary_exact_named(jnp.greater_equal, 'greater_equal'),
160
- '<=': _function_binary_exact_named(jnp.less_equal, 'less_equal'),
161
- '<': _function_binary_exact_named(jnp.less, 'less'),
162
- '>': _function_binary_exact_named(jnp.greater, 'greater'),
163
- '==': _function_binary_exact_named(jnp.equal, 'equal'),
164
- '~=': _function_binary_exact_named(jnp.not_equal, 'not_equal')
165
- }
166
-
63
+ '>=': wrap_logic(ExactLogic.exact_binary_function(jnp.greater_equal)),
64
+ '<=': wrap_logic(ExactLogic.exact_binary_function(jnp.less_equal)),
65
+ '<': wrap_logic(ExactLogic.exact_binary_function(jnp.less)),
66
+ '>': wrap_logic(ExactLogic.exact_binary_function(jnp.greater)),
67
+ '==': wrap_logic(ExactLogic.exact_binary_function(jnp.equal)),
68
+ '~=': wrap_logic(ExactLogic.exact_binary_function(jnp.not_equal))
69
+ }
70
+
71
+ EXACT_RDDL_TO_JAX_LOGICAL_NOT = wrap_logic(ExactLogic.exact_unary_function(jnp.logical_not))
167
72
  EXACT_RDDL_TO_JAX_LOGICAL = {
168
- '^': _function_binary_exact_named(jnp.logical_and, 'and'),
169
- '&': _function_binary_exact_named(jnp.logical_and, 'and'),
170
- '|': _function_binary_exact_named(jnp.logical_or, 'or'),
171
- '~': _function_binary_exact_named(jnp.logical_xor, 'xor'),
172
- '=>': _function_binary_exact_named_implies(),
173
- '<=>': _function_binary_exact_named(jnp.equal, 'iff')
174
- }
175
-
176
- EXACT_RDDL_TO_JAX_LOGICAL_NOT = _function_unary_exact_named(jnp.logical_not, 'not')
73
+ '^': wrap_logic(ExactLogic.exact_binary_function(jnp.logical_and)),
74
+ '&': wrap_logic(ExactLogic.exact_binary_function(jnp.logical_and)),
75
+ '|': wrap_logic(ExactLogic.exact_binary_function(jnp.logical_or)),
76
+ '~': wrap_logic(ExactLogic.exact_binary_function(jnp.logical_xor)),
77
+ '=>': wrap_logic(ExactLogic.exact_binary_implies),
78
+ '<=>': wrap_logic(ExactLogic.exact_binary_function(jnp.equal))
79
+ }
177
80
 
178
81
  EXACT_RDDL_TO_JAX_AGGREGATION = {
179
- 'sum': _function_aggregation_exact_named(jnp.sum, 'sum'),
180
- 'avg': _function_aggregation_exact_named(jnp.mean, 'avg'),
181
- 'prod': _function_aggregation_exact_named(jnp.prod, 'prod'),
182
- 'minimum': _function_aggregation_exact_named(jnp.min, 'minimum'),
183
- 'maximum': _function_aggregation_exact_named(jnp.max, 'maximum'),
184
- 'forall': _function_aggregation_exact_named(jnp.all, 'forall'),
185
- 'exists': _function_aggregation_exact_named(jnp.any, 'exists'),
186
- 'argmin': _function_aggregation_exact_named(jnp.argmin, 'argmin'),
187
- 'argmax': _function_aggregation_exact_named(jnp.argmax, 'argmax')
82
+ 'sum': wrap_logic(ExactLogic.exact_aggregation(jnp.sum)),
83
+ 'avg': wrap_logic(ExactLogic.exact_aggregation(jnp.mean)),
84
+ 'prod': wrap_logic(ExactLogic.exact_aggregation(jnp.prod)),
85
+ 'minimum': wrap_logic(ExactLogic.exact_aggregation(jnp.min)),
86
+ 'maximum': wrap_logic(ExactLogic.exact_aggregation(jnp.max)),
87
+ 'forall': wrap_logic(ExactLogic.exact_aggregation(jnp.all)),
88
+ 'exists': wrap_logic(ExactLogic.exact_aggregation(jnp.any)),
89
+ 'argmin': wrap_logic(ExactLogic.exact_aggregation(jnp.argmin)),
90
+ 'argmax': wrap_logic(ExactLogic.exact_aggregation(jnp.argmax))
188
91
  }
189
92
 
190
93
  EXACT_RDDL_TO_JAX_UNARY = {
191
- 'abs': _function_unary_exact_named(jnp.abs, 'abs'),
192
- 'sgn': _function_unary_exact_named(jnp.sign, 'sgn'),
193
- 'round': _function_unary_exact_named(jnp.round, 'round'),
194
- 'floor': _function_unary_exact_named(jnp.floor, 'floor'),
195
- 'ceil': _function_unary_exact_named(jnp.ceil, 'ceil'),
196
- 'cos': _function_unary_exact_named(jnp.cos, 'cos'),
197
- 'sin': _function_unary_exact_named(jnp.sin, 'sin'),
198
- 'tan': _function_unary_exact_named(jnp.tan, 'tan'),
199
- 'acos': _function_unary_exact_named(jnp.arccos, 'acos'),
200
- 'asin': _function_unary_exact_named(jnp.arcsin, 'asin'),
201
- 'atan': _function_unary_exact_named(jnp.arctan, 'atan'),
202
- 'cosh': _function_unary_exact_named(jnp.cosh, 'cosh'),
203
- 'sinh': _function_unary_exact_named(jnp.sinh, 'sinh'),
204
- 'tanh': _function_unary_exact_named(jnp.tanh, 'tanh'),
205
- 'exp': _function_unary_exact_named(jnp.exp, 'exp'),
206
- 'ln': _function_unary_exact_named(jnp.log, 'ln'),
207
- 'sqrt': _function_unary_exact_named(jnp.sqrt, 'sqrt'),
208
- 'lngamma': _function_unary_exact_named(scipy.special.gammaln, 'lngamma'),
209
- 'gamma': _function_unary_exact_named_gamma()
210
- }
211
-
94
+ 'abs': wrap_logic(ExactLogic.exact_unary_function(jnp.abs)),
95
+ 'sgn': wrap_logic(ExactLogic.exact_unary_function(jnp.sign)),
96
+ 'round': wrap_logic(ExactLogic.exact_unary_function(jnp.round)),
97
+ 'floor': wrap_logic(ExactLogic.exact_unary_function(jnp.floor)),
98
+ 'ceil': wrap_logic(ExactLogic.exact_unary_function(jnp.ceil)),
99
+ 'cos': wrap_logic(ExactLogic.exact_unary_function(jnp.cos)),
100
+ 'sin': wrap_logic(ExactLogic.exact_unary_function(jnp.sin)),
101
+ 'tan': wrap_logic(ExactLogic.exact_unary_function(jnp.tan)),
102
+ 'acos': wrap_logic(ExactLogic.exact_unary_function(jnp.arccos)),
103
+ 'asin': wrap_logic(ExactLogic.exact_unary_function(jnp.arcsin)),
104
+ 'atan': wrap_logic(ExactLogic.exact_unary_function(jnp.arctan)),
105
+ 'cosh': wrap_logic(ExactLogic.exact_unary_function(jnp.cosh)),
106
+ 'sinh': wrap_logic(ExactLogic.exact_unary_function(jnp.sinh)),
107
+ 'tanh': wrap_logic(ExactLogic.exact_unary_function(jnp.tanh)),
108
+ 'exp': wrap_logic(ExactLogic.exact_unary_function(jnp.exp)),
109
+ 'ln': wrap_logic(ExactLogic.exact_unary_function(jnp.log)),
110
+ 'sqrt': wrap_logic(ExactLogic.exact_unary_function(jnp.sqrt)),
111
+ 'lngamma': wrap_logic(ExactLogic.exact_unary_function(scipy.special.gammaln)),
112
+ 'gamma': wrap_logic(ExactLogic.exact_unary_function(scipy.special.gamma))
113
+ }
114
+
115
+ @staticmethod
116
+ def _jax_wrapped_calc_log_exact(x, y, params):
117
+ return jnp.log(x) / jnp.log(y), params
118
+
212
119
  EXACT_RDDL_TO_JAX_BINARY = {
213
- 'div': _function_binary_exact_named(jnp.floor_divide, 'div'),
214
- 'mod': _function_binary_exact_named(jnp.mod, 'mod'),
215
- 'fmod': _function_binary_exact_named(jnp.mod, 'fmod'),
216
- 'min': _function_binary_exact_named(jnp.minimum, 'min'),
217
- 'max': _function_binary_exact_named(jnp.maximum, 'max'),
218
- 'pow': _function_binary_exact_named(jnp.power, 'pow'),
219
- 'log': _function_binary_exact_named_log(),
220
- 'hypot': _function_binary_exact_named(jnp.hypot, 'hypot'),
120
+ 'div': wrap_logic(ExactLogic.exact_binary_function(jnp.floor_divide)),
121
+ 'mod': wrap_logic(ExactLogic.exact_binary_function(jnp.mod)),
122
+ 'fmod': wrap_logic(ExactLogic.exact_binary_function(jnp.mod)),
123
+ 'min': wrap_logic(ExactLogic.exact_binary_function(jnp.minimum)),
124
+ 'max': wrap_logic(ExactLogic.exact_binary_function(jnp.maximum)),
125
+ 'pow': wrap_logic(ExactLogic.exact_binary_function(jnp.power)),
126
+ 'log': wrap_logic(_jax_wrapped_calc_log_exact),
127
+ 'hypot': wrap_logic(ExactLogic.exact_binary_function(jnp.hypot)),
221
128
  }
222
129
 
223
- EXACT_RDDL_TO_JAX_IF = _function_if_exact_named()
224
- EXACT_RDDL_TO_JAX_SWITCH = _function_switch_exact_named()
130
+ EXACT_RDDL_TO_JAX_IF = wrap_logic(ExactLogic.exact_if_then_else)
131
+ EXACT_RDDL_TO_JAX_SWITCH = wrap_logic(ExactLogic.exact_switch)
225
132
 
226
- EXACT_RDDL_TO_JAX_BERNOULLI = _function_bernoulli_exact_named()
227
- EXACT_RDDL_TO_JAX_DISCRETE = _function_discrete_exact_named()
228
- EXACT_RDDL_TO_JAX_POISSON = _function_poisson_exact_named()
229
- EXACT_RDDL_TO_JAX_GEOMETRIC = _function_geometric_exact_named()
133
+ EXACT_RDDL_TO_JAX_BERNOULLI = wrap_logic(ExactLogic.exact_bernoulli)
134
+ EXACT_RDDL_TO_JAX_DISCRETE = wrap_logic(ExactLogic.exact_discrete)
135
+ EXACT_RDDL_TO_JAX_POISSON = wrap_logic(ExactLogic.exact_poisson)
136
+ EXACT_RDDL_TO_JAX_GEOMETRIC = wrap_logic(ExactLogic.exact_geometric)
230
137
 
231
138
  def __init__(self, rddl: RDDLLiftedModel,
232
139
  allow_synchronous_state: bool=True,
@@ -251,18 +158,17 @@ class JaxRDDLCompiler:
251
158
  if use64bit:
252
159
  self.INT = jnp.int64
253
160
  self.REAL = jnp.float64
254
- jax.config.update('jax_enable_x64', True)
255
161
  else:
256
162
  self.INT = jnp.int32
257
163
  self.REAL = jnp.float32
258
- jax.config.update('jax_enable_x64', False)
164
+ jax.config.update('jax_enable_x64', use64bit)
259
165
  self.ONE = jnp.asarray(1, dtype=self.INT)
260
166
  self.JAX_TYPES = {
261
167
  'int': self.INT,
262
168
  'real': self.REAL,
263
169
  'bool': bool
264
170
  }
265
-
171
+
266
172
  # compile initial values
267
173
  initializer = RDDLValueInitializer(rddl)
268
174
  self.init_values = initializer.initialize()
@@ -314,15 +220,13 @@ class JaxRDDLCompiler:
314
220
  to the log file
315
221
  :param heading: the heading to print before compilation information
316
222
  '''
317
- info = ({}, [])
318
- self.invariants = self._compile_constraints(self.rddl.invariants, info)
319
- self.preconditions = self._compile_constraints(self.rddl.preconditions, info)
320
- self.terminations = self._compile_constraints(self.rddl.terminations, info)
321
- self.cpfs = self._compile_cpfs(info)
322
- self.reward = self._compile_reward(info)
323
- self.model_params = {key: value
324
- for (key, (value, *_)) in info[0].items()}
325
- self.relaxations = info[1]
223
+ init_params = {}
224
+ self.invariants = self._compile_constraints(self.rddl.invariants, init_params)
225
+ self.preconditions = self._compile_constraints(self.rddl.preconditions, init_params)
226
+ self.terminations = self._compile_constraints(self.rddl.terminations, init_params)
227
+ self.cpfs = self._compile_cpfs(init_params)
228
+ self.reward = self._compile_reward(init_params)
229
+ self.model_params = init_params
326
230
 
327
231
  if log_jax_expr and self.logger is not None:
328
232
  printed = self.print_jax()
@@ -332,7 +236,7 @@ class JaxRDDLCompiler:
332
236
  printed_invariants = '\n\n'.join(v for v in printed['invariants'])
333
237
  printed_preconds = '\n\n'.join(v for v in printed['preconditions'])
334
238
  printed_terminals = '\n\n'.join(v for v in printed['terminations'])
335
- printed_params = '\n'.join(f'{k}: {v}' for (k, v) in info.items())
239
+ printed_params = '\n'.join(f'{k}: {v}' for (k, v) in init_params.items())
336
240
  message = (
337
241
  f'[info] {heading}\n'
338
242
  f'[info] compiled JAX CPFs:\n\n'
@@ -350,21 +254,21 @@ class JaxRDDLCompiler:
350
254
  )
351
255
  self.logger.log(message)
352
256
 
353
- def _compile_constraints(self, constraints, info):
354
- return [self._jax(expr, info, dtype=bool) for expr in constraints]
257
+ def _compile_constraints(self, constraints, init_params):
258
+ return [self._jax(expr, init_params, dtype=bool) for expr in constraints]
355
259
 
356
- def _compile_cpfs(self, info):
260
+ def _compile_cpfs(self, init_params):
357
261
  jax_cpfs = {}
358
262
  for cpfs in self.levels.values():
359
263
  for cpf in cpfs:
360
264
  _, expr = self.rddl.cpfs[cpf]
361
265
  prange = self.rddl.variable_ranges[cpf]
362
266
  dtype = self.JAX_TYPES.get(prange, self.INT)
363
- jax_cpfs[cpf] = self._jax(expr, info, dtype=dtype)
267
+ jax_cpfs[cpf] = self._jax(expr, init_params, dtype=dtype)
364
268
  return jax_cpfs
365
269
 
366
- def _compile_reward(self, info):
367
- return self._jax(self.rddl.reward, info, dtype=self.REAL)
270
+ def _compile_reward(self, init_params):
271
+ return self._jax(self.rddl.reward, init_params, dtype=self.REAL)
368
272
 
369
273
  def _extract_inequality_constraint(self, expr):
370
274
  result = []
@@ -392,7 +296,7 @@ class JaxRDDLCompiler:
392
296
  result.extend(self._extract_equality_constraint(arg))
393
297
  return result
394
298
 
395
- def _jax_nonlinear_constraints(self):
299
+ def _jax_nonlinear_constraints(self, init_params):
396
300
  rddl = self.rddl
397
301
 
398
302
  # extract the non-box inequality constraints on actions
@@ -402,12 +306,12 @@ class JaxRDDLCompiler:
402
306
  if not self.constraints.is_box_preconditions[i]]
403
307
 
404
308
  # compile them to JAX and write as h(s, a) <= 0
405
- op = self.ARITHMETIC_OPS['-']
309
+ jax_op = ExactLogic.exact_binary_function(jnp.subtract)
406
310
  jax_inequalities = []
407
311
  for (left, right) in inequalities:
408
- jax_lhs = self._jax(left, {})
409
- jax_rhs = self._jax(right, {})
410
- jax_constr = self._jax_binary(jax_lhs, jax_rhs, op, '', at_least_int=True)
312
+ jax_lhs = self._jax(left, init_params)
313
+ jax_rhs = self._jax(right, init_params)
314
+ jax_constr = self._jax_binary(jax_lhs, jax_rhs, jax_op, at_least_int=True)
411
315
  jax_inequalities.append(jax_constr)
412
316
 
413
317
  # extract the non-box equality constraints on actions
@@ -419,15 +323,16 @@ class JaxRDDLCompiler:
419
323
  # compile them to JAX and write as g(s, a) == 0
420
324
  jax_equalities = []
421
325
  for (left, right) in equalities:
422
- jax_lhs = self._jax(left, {})
423
- jax_rhs = self._jax(right, {})
424
- jax_constr = self._jax_binary(jax_lhs, jax_rhs, op, '', at_least_int=True)
326
+ jax_lhs = self._jax(left, init_params)
327
+ jax_rhs = self._jax(right, init_params)
328
+ jax_constr = self._jax_binary(jax_lhs, jax_rhs, jax_op, at_least_int=True)
425
329
  jax_equalities.append(jax_constr)
426
330
 
427
331
  return jax_inequalities, jax_equalities
428
332
 
429
333
  def compile_transition(self, check_constraints: bool=False,
430
- constraint_func: bool=False) -> Callable:
334
+ constraint_func: bool=False,
335
+ init_params_constr: Dict[str, Any]={}) -> Callable:
431
336
  '''Compiles the current RDDL model into a JAX transition function that
432
337
  samples the next state.
433
338
 
@@ -467,13 +372,12 @@ class JaxRDDLCompiler:
467
372
  '''
468
373
  NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
469
374
  rddl = self.rddl
470
- reward_fn, cpfs = self.reward, self.cpfs
471
- preconds, invariants, terminals = \
472
- self.preconditions, self.invariants, self.terminations
375
+ reward_fn, cpfs, preconds, invariants, terminals = \
376
+ self.reward, self.cpfs, self.preconditions, self.invariants, self.terminations
473
377
 
474
378
  # compile constraint information
475
379
  if constraint_func:
476
- inequality_fns, equality_fns = self._jax_nonlinear_constraints()
380
+ inequality_fns, equality_fns = self._jax_nonlinear_constraints(init_params_constr)
477
381
  else:
478
382
  inequality_fns, equality_fns = None, None
479
383
 
@@ -486,7 +390,7 @@ class JaxRDDLCompiler:
486
390
  precond_check = True
487
391
  if check_constraints:
488
392
  for precond in preconds:
489
- sample, key, err = precond(subs, model_params, key)
393
+ sample, key, err, model_params = precond(subs, model_params, key)
490
394
  precond_check = jnp.logical_and(precond_check, sample)
491
395
  errors |= err
492
396
 
@@ -494,21 +398,21 @@ class JaxRDDLCompiler:
494
398
  inequalities, equalities = [], []
495
399
  if constraint_func:
496
400
  for constraint in inequality_fns:
497
- sample, key, err = constraint(subs, model_params, key)
401
+ sample, key, err, model_params = constraint(subs, model_params, key)
498
402
  inequalities.append(sample)
499
403
  errors |= err
500
404
  for constraint in equality_fns:
501
- sample, key, err = constraint(subs, model_params, key)
405
+ sample, key, err, model_params = constraint(subs, model_params, key)
502
406
  equalities.append(sample)
503
407
  errors |= err
504
408
 
505
409
  # calculate CPFs in topological order
506
410
  for (name, cpf) in cpfs.items():
507
- subs[name], key, err = cpf(subs, model_params, key)
411
+ subs[name], key, err, model_params = cpf(subs, model_params, key)
508
412
  errors |= err
509
413
 
510
414
  # calculate the immediate reward
511
- reward, key, err = reward_fn(subs, model_params, key)
415
+ reward, key, err, model_params = reward_fn(subs, model_params, key)
512
416
  errors |= err
513
417
 
514
418
  # calculate fluent values
@@ -523,7 +427,7 @@ class JaxRDDLCompiler:
523
427
  invariant_check = True
524
428
  if check_constraints:
525
429
  for invariant in invariants:
526
- sample, key, err = invariant(subs, model_params, key)
430
+ sample, key, err, model_params = invariant(subs, model_params, key)
527
431
  invariant_check = jnp.logical_and(invariant_check, sample)
528
432
  errors |= err
529
433
 
@@ -531,7 +435,7 @@ class JaxRDDLCompiler:
531
435
  terminated_check = False
532
436
  if check_constraints:
533
437
  for terminal in terminals:
534
- sample, key, err = terminal(subs, model_params, key)
438
+ sample, key, err, model_params = terminal(subs, model_params, key)
535
439
  terminated_check = jnp.logical_or(terminated_check, sample)
536
440
  errors |= err
537
441
 
@@ -548,7 +452,7 @@ class JaxRDDLCompiler:
548
452
  log['inequalities'] = inequalities
549
453
  log['equalities'] = equalities
550
454
 
551
- return subs, log
455
+ return subs, log, model_params
552
456
 
553
457
  return _jax_wrapped_single_step
554
458
 
@@ -556,7 +460,8 @@ class JaxRDDLCompiler:
556
460
  n_steps: int,
557
461
  n_batch: int,
558
462
  check_constraints: bool=False,
559
- constraint_func: bool=False) -> Callable:
463
+ constraint_func: bool=False,
464
+ init_params_constr: Dict[str, Any]={}) -> Callable:
560
465
  '''Compiles the current RDDL model into a JAX transition function that
561
466
  samples trajectories with a fixed horizon from a policy.
562
467
 
@@ -569,7 +474,8 @@ class JaxRDDLCompiler:
569
474
 
570
475
  The returned value of the returned function is:
571
476
  - log is the dictionary of all trajectory information, including
572
- constraints that were satisfied, errors, etc.
477
+ constraints that were satisfied, errors, etc
478
+ - model_params is the final set of model parameters.
573
479
 
574
480
  The arguments of the policy function is:
575
481
  - key is the PRNG key (used by a stochastic policy)
@@ -589,7 +495,8 @@ class JaxRDDLCompiler:
589
495
  in addition to the usual outputs
590
496
  '''
591
497
  rddl = self.rddl
592
- jax_step_fn = self.compile_transition(check_constraints, constraint_func)
498
+ jax_step_fn = self.compile_transition(
499
+ check_constraints, constraint_func, init_params_constr)
593
500
 
594
501
  # for POMDP only observ-fluents are assumed visible to the policy
595
502
  if rddl.observ_fluents:
@@ -605,18 +512,19 @@ class JaxRDDLCompiler:
605
512
  if var in observed_vars}
606
513
  actions = policy(key, policy_params, hyperparams, step, states)
607
514
  key, subkey = random.split(key)
608
- subs, log = jax_step_fn(subkey, actions, subs, model_params)
609
- return subs, log
515
+ return jax_step_fn(subkey, actions, subs, model_params)
610
516
 
611
517
  # do a batched step update from the policy
518
+ # TODO: come up with a better way to reduce the model_param batch dim
612
519
  def _jax_wrapped_batched_step_policy(carry, step):
613
520
  key, policy_params, hyperparams, subs, model_params = carry
614
521
  key, *subkeys = random.split(key, num=1 + n_batch)
615
522
  keys = jnp.asarray(subkeys)
616
- subs, log = jax.vmap(
523
+ subs, log, model_params = jax.vmap(
617
524
  _jax_wrapped_single_step_policy,
618
525
  in_axes=(0, None, None, None, 0, None)
619
526
  )(keys, policy_params, hyperparams, step, subs, model_params)
527
+ model_params = jax.tree_map(lambda x: jnp.mean(x, axis=0), model_params)
620
528
  carry = (key, policy_params, hyperparams, subs, model_params)
621
529
  return carry, log
622
530
 
@@ -625,14 +533,15 @@ class JaxRDDLCompiler:
625
533
  subs, model_params):
626
534
  start = (key, policy_params, hyperparams, subs, model_params)
627
535
  steps = jnp.arange(n_steps)
628
- _, log = jax.lax.scan(_jax_wrapped_batched_step_policy, start, steps)
536
+ end, log = jax.lax.scan(_jax_wrapped_batched_step_policy, start, steps)
629
537
  log = jax.tree_map(partial(jnp.swapaxes, axis1=0, axis2=1), log)
630
- return log
538
+ model_params = end[-1]
539
+ return log, model_params
631
540
 
632
541
  return _jax_wrapped_batched_rollout
633
542
 
634
543
  # ===========================================================================
635
- # error checks
544
+ # error checks and prints
636
545
  # ===========================================================================
637
546
 
638
547
  def print_jax(self) -> Dict[str, Any]:
@@ -640,19 +549,30 @@ class JaxRDDLCompiler:
640
549
  Jax compiled expressions from the RDDL file.
641
550
  '''
642
551
  subs = self.init_values
643
- params = self.model_params
552
+ init_params = self.model_params
644
553
  key = jax.random.PRNGKey(42)
645
- printed = {}
646
- printed['cpfs'] = {name: str(jax.make_jaxpr(expr)(subs, params, key))
647
- for (name, expr) in self.cpfs.items()}
648
- printed['reward'] = str(jax.make_jaxpr(self.reward)(subs, params, key))
649
- printed['invariants'] = [str(jax.make_jaxpr(expr)(subs, params, key))
650
- for expr in self.invariants]
651
- printed['preconditions'] = [str(jax.make_jaxpr(expr)(subs, params, key))
652
- for expr in self.preconditions]
653
- printed['terminations'] = [str(jax.make_jaxpr(expr)(subs, params, key))
654
- for expr in self.terminations]
554
+ printed = {
555
+ 'cpfs': {name: str(jax.make_jaxpr(expr)(subs, init_params, key))
556
+ for (name, expr) in self.cpfs.items()},
557
+ 'reward': str(jax.make_jaxpr(self.reward)(subs, init_params, key)),
558
+ 'invariants': [str(jax.make_jaxpr(expr)(subs, init_params, key))
559
+ for expr in self.invariants],
560
+ 'preconditions': [str(jax.make_jaxpr(expr)(subs, init_params, key))
561
+ for expr in self.preconditions],
562
+ 'terminations': [str(jax.make_jaxpr(expr)(subs, init_params, key))
563
+ for expr in self.terminations]
564
+ }
655
565
  return printed
566
+
567
+ def model_parameter_info(self) -> Dict[str, Dict[str, Any]]:
568
+ '''Returns a dictionary of additional information about model
569
+ parameters.'''
570
+ result = {}
571
+ for (id, value) in self.model_params.items():
572
+ expr_id = int(str(id).split('_')[0])
573
+ expr = self.traced.lookup(expr_id)
574
+ result[id] = {'id': expr_id, 'rddl_op': ' '.join(expr.etype), 'init_value': value}
575
+ return result
656
576
 
657
577
  @staticmethod
658
578
  def _check_valid_op(expr, valid_ops):
@@ -661,8 +581,7 @@ class JaxRDDLCompiler:
661
581
  valid_op_str = ','.join(valid_ops.keys())
662
582
  raise RDDLNotImplementedError(
663
583
  f'{etype} operator {op} is not supported: '
664
- f'must be in {valid_op_str}.\n' +
665
- print_stack_trace(expr))
584
+ f'must be in {valid_op_str}.\n' + print_stack_trace(expr))
666
585
 
667
586
  @staticmethod
668
587
  def _check_num_args(expr, required_args):
@@ -671,8 +590,7 @@ class JaxRDDLCompiler:
671
590
  etype, op = expr.etype
672
591
  raise RDDLInvalidNumberOfArgumentsError(
673
592
  f'{etype} operator {op} requires {required_args} arguments, '
674
- f'got {actual_args}.\n' +
675
- print_stack_trace(expr))
593
+ f'got {actual_args}.\n' + print_stack_trace(expr))
676
594
 
677
595
  ERROR_CODES = {
678
596
  'NORMAL': 0,
@@ -749,73 +667,34 @@ class JaxRDDLCompiler:
749
667
  messages = [JaxRDDLCompiler.INVERSE_ERROR_CODES[i] for i in codes]
750
668
  return messages
751
669
 
752
- # ===========================================================================
753
- # handling of auxiliary data (e.g. model tuning parameters)
754
- # ===========================================================================
755
-
756
- def _unwrap(self, op, expr_id, info):
757
- jax_op, name = op, None
758
- model_params, relaxed_list = info
759
- if isinstance(op, tuple):
760
- jax_op, param = op
761
- if param is not None:
762
- tags, values = param
763
- sep = JaxRDDLCompiler.MODEL_PARAM_TAG_SEPARATOR
764
- if isinstance(tags, tuple):
765
- name = sep.join(tags)
766
- else:
767
- name = str(tags)
768
- name = f'{name}{sep}{expr_id}'
769
- if name in model_params:
770
- raise RuntimeError(
771
- f'Internal error: model parameter {name} is already defined.')
772
- model_params[name] = (values, tags, expr_id, jax_op.__name__)
773
- relaxed_list.append((param, expr_id, jax_op.__name__))
774
- return jax_op, name
775
-
776
- def summarize_model_relaxations(self) -> str:
777
- '''Returns a string of information about model relaxations in the
778
- compiled model.'''
779
- occurence_by_type = {}
780
- for (_, expr_id, jax_op) in self.relaxations:
781
- etype = self.traced.lookup(expr_id).etype
782
- source = f'{etype[1]} ({etype[0]})'
783
- sub = f'{source:<30} --> {jax_op}'
784
- occurence_by_type[sub] = occurence_by_type.get(sub, 0) + 1
785
- col = "{:<80} {:<10}\n"
786
- table = col.format('Substitution', 'Count')
787
- for (sub, occurs) in occurence_by_type.items():
788
- table += col.format(sub, occurs)
789
- return table
790
-
791
670
  # ===========================================================================
792
671
  # expression compilation
793
672
  # ===========================================================================
794
673
 
795
- def _jax(self, expr, info, dtype=None):
674
+ def _jax(self, expr, init_params, dtype=None):
796
675
  etype, _ = expr.etype
797
676
  if etype == 'constant':
798
- jax_expr = self._jax_constant(expr, info)
677
+ jax_expr = self._jax_constant(expr, init_params)
799
678
  elif etype == 'pvar':
800
- jax_expr = self._jax_pvar(expr, info)
679
+ jax_expr = self._jax_pvar(expr, init_params)
801
680
  elif etype == 'arithmetic':
802
- jax_expr = self._jax_arithmetic(expr, info)
681
+ jax_expr = self._jax_arithmetic(expr, init_params)
803
682
  elif etype == 'relational':
804
- jax_expr = self._jax_relational(expr, info)
683
+ jax_expr = self._jax_relational(expr, init_params)
805
684
  elif etype == 'boolean':
806
- jax_expr = self._jax_logical(expr, info)
685
+ jax_expr = self._jax_logical(expr, init_params)
807
686
  elif etype == 'aggregation':
808
- jax_expr = self._jax_aggregation(expr, info)
687
+ jax_expr = self._jax_aggregation(expr, init_params)
809
688
  elif etype == 'func':
810
- jax_expr = self._jax_functional(expr, info)
689
+ jax_expr = self._jax_functional(expr, init_params)
811
690
  elif etype == 'control':
812
- jax_expr = self._jax_control(expr, info)
691
+ jax_expr = self._jax_control(expr, init_params)
813
692
  elif etype == 'randomvar':
814
- jax_expr = self._jax_random(expr, info)
693
+ jax_expr = self._jax_random(expr, init_params)
815
694
  elif etype == 'randomvector':
816
- jax_expr = self._jax_random_vector(expr, info)
695
+ jax_expr = self._jax_random_vector(expr, init_params)
817
696
  elif etype == 'matrix':
818
- jax_expr = self._jax_matrix(expr, info)
697
+ jax_expr = self._jax_matrix(expr, init_params)
819
698
  else:
820
699
  raise RDDLNotImplementedError(
821
700
  f'Internal error: expression type {expr} is not supported.\n' +
@@ -831,13 +710,14 @@ class JaxRDDLCompiler:
831
710
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
832
711
 
833
712
  def _jax_wrapped_cast(x, params, key):
834
- val, key, err = jax_expr(x, params, key)
713
+ val, key, err, params = jax_expr(x, params, key)
835
714
  sample = jnp.asarray(val, dtype=dtype)
836
715
  invalid_cast = jnp.logical_and(
837
716
  jnp.logical_not(jnp.can_cast(val, dtype)),
838
- jnp.any(sample != val))
717
+ jnp.any(sample != val)
718
+ )
839
719
  err |= (invalid_cast * ERR)
840
- return sample, key, err
720
+ return sample, key, err, params
841
721
 
842
722
  return _jax_wrapped_cast
843
723
 
@@ -856,13 +736,13 @@ class JaxRDDLCompiler:
856
736
  # leaves
857
737
  # ===========================================================================
858
738
 
859
- def _jax_constant(self, expr, info):
739
+ def _jax_constant(self, expr, init_params):
860
740
  NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
861
741
  cached_value = self.traced.cached_sim_info(expr)
862
742
 
863
743
  def _jax_wrapped_constant(x, params, key):
864
744
  sample = jnp.asarray(cached_value, dtype=self._fix_dtype(cached_value))
865
- return sample, key, NORMAL
745
+ return sample, key, NORMAL, params
866
746
 
867
747
  return _jax_wrapped_constant
868
748
 
@@ -870,11 +750,11 @@ class JaxRDDLCompiler:
870
750
  NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
871
751
 
872
752
  def _jax_wrapped_pvar_slice(x, params, key):
873
- return _slice, key, NORMAL
753
+ return _slice, key, NORMAL, params
874
754
 
875
755
  return _jax_wrapped_pvar_slice
876
756
 
877
- def _jax_pvar(self, expr, info):
757
+ def _jax_pvar(self, expr, init_params):
878
758
  NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
879
759
  var, pvars = expr.args
880
760
  is_value, cached_info = self.traced.cached_sim_info(expr)
@@ -886,7 +766,7 @@ class JaxRDDLCompiler:
886
766
 
887
767
  def _jax_wrapped_object(x, params, key):
888
768
  sample = jnp.asarray(cached_value, dtype=self._fix_dtype(cached_value))
889
- return sample, key, NORMAL
769
+ return sample, key, NORMAL, params
890
770
 
891
771
  return _jax_wrapped_object
892
772
 
@@ -896,7 +776,7 @@ class JaxRDDLCompiler:
896
776
  def _jax_wrapped_pvar_scalar(x, params, key):
897
777
  value = x[var]
898
778
  sample = jnp.asarray(value, dtype=self._fix_dtype(value))
899
- return sample, key, NORMAL
779
+ return sample, key, NORMAL, params
900
780
 
901
781
  return _jax_wrapped_pvar_scalar
902
782
 
@@ -907,7 +787,7 @@ class JaxRDDLCompiler:
907
787
  # compile nested expressions
908
788
  if slices and op_code == RDDLObjectsTracer.NUMPY_OP_CODE.NESTED_SLICE:
909
789
 
910
- jax_nested_expr = [(self._jax(arg, info)
790
+ jax_nested_expr = [(self._jax(arg, init_params)
911
791
  if _slice is None
912
792
  else self._jax_pvar_slice(_slice))
913
793
  for (arg, _slice) in zip(pvars, slices)]
@@ -918,11 +798,11 @@ class JaxRDDLCompiler:
918
798
  sample = jnp.asarray(value, dtype=self._fix_dtype(value))
919
799
  new_slices = [None] * len(jax_nested_expr)
920
800
  for (i, jax_expr) in enumerate(jax_nested_expr):
921
- new_slices[i], key, err = jax_expr(x, params, key)
801
+ new_slices[i], key, err, params = jax_expr(x, params, key)
922
802
  error |= err
923
803
  new_slices = tuple(new_slices)
924
804
  sample = sample[new_slices]
925
- return sample, key, error
805
+ return sample, key, error, params
926
806
 
927
807
  return _jax_wrapped_pvar_tensor_nested
928
808
 
@@ -941,7 +821,7 @@ class JaxRDDLCompiler:
941
821
  sample = jnp.einsum(sample, *op_args)
942
822
  elif op_code == RDDLObjectsTracer.NUMPY_OP_CODE.TRANSPOSE:
943
823
  sample = jnp.transpose(sample, axes=op_args)
944
- return sample, key, NORMAL
824
+ return sample, key, NORMAL, params
945
825
 
946
826
  return _jax_wrapped_pvar_tensor_non_nested
947
827
 
@@ -949,46 +829,43 @@ class JaxRDDLCompiler:
949
829
  # mathematical
950
830
  # ===========================================================================
951
831
 
952
- def _jax_unary(self, jax_expr, jax_op, jax_param,
953
- at_least_int=False, check_dtype=None):
832
+ def _jax_unary(self, jax_expr, jax_op, at_least_int=False, check_dtype=None):
954
833
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
955
834
 
956
835
  def _jax_wrapped_unary_op(x, params, key):
957
- sample, key, err = jax_expr(x, params, key)
836
+ sample, key, err, params = jax_expr(x, params, key)
958
837
  if at_least_int:
959
838
  sample = self.ONE * sample
960
- param = params.get(jax_param, None)
961
- sample = jax_op(sample, param)
839
+ sample, params = jax_op(sample, params)
962
840
  if check_dtype is not None:
963
841
  invalid_cast = jnp.logical_not(jnp.can_cast(sample, check_dtype))
964
842
  err |= (invalid_cast * ERR)
965
- return sample, key, err
843
+ return sample, key, err, params
966
844
 
967
845
  return _jax_wrapped_unary_op
968
846
 
969
- def _jax_binary(self, jax_lhs, jax_rhs, jax_op, jax_param,
970
- at_least_int=False, check_dtype=None):
847
+ def _jax_binary(self, jax_lhs, jax_rhs, jax_op, at_least_int=False, check_dtype=None):
971
848
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
972
849
 
973
850
  def _jax_wrapped_binary_op(x, params, key):
974
- sample1, key, err1 = jax_lhs(x, params, key)
975
- sample2, key, err2 = jax_rhs(x, params, key)
851
+ sample1, key, err1, params = jax_lhs(x, params, key)
852
+ sample2, key, err2, params = jax_rhs(x, params, key)
976
853
  if at_least_int:
977
854
  sample1 = self.ONE * sample1
978
855
  sample2 = self.ONE * sample2
979
- param = params.get(jax_param, None)
980
- sample = jax_op(sample1, sample2, param)
856
+ sample, params = jax_op(sample1, sample2, params)
981
857
  err = err1 | err2
982
858
  if check_dtype is not None:
983
859
  invalid_cast = jnp.logical_not(jnp.logical_and(
984
860
  jnp.can_cast(sample1, check_dtype),
985
- jnp.can_cast(sample2, check_dtype)))
861
+ jnp.can_cast(sample2, check_dtype))
862
+ )
986
863
  err |= (invalid_cast * ERR)
987
- return sample, key, err
864
+ return sample, key, err, params
988
865
 
989
866
  return _jax_wrapped_binary_op
990
867
 
991
- def _jax_arithmetic(self, expr, info):
868
+ def _jax_arithmetic(self, expr, init_params):
992
869
  _, op = expr.etype
993
870
 
994
871
  # if expression is non-fluent, always use the exact operation
@@ -1005,22 +882,21 @@ class JaxRDDLCompiler:
1005
882
  n = len(args)
1006
883
  if n == 1 and op == '-':
1007
884
  arg, = args
1008
- jax_expr = self._jax(arg, info)
1009
- jax_op, jax_param = self._unwrap(negative_op, expr.id, info)
1010
- return self._jax_unary(jax_expr, jax_op, jax_param, at_least_int=True)
885
+ jax_expr = self._jax(arg, init_params)
886
+ jax_op = negative_op(expr.id, init_params)
887
+ return self._jax_unary(jax_expr, jax_op, at_least_int=True)
1011
888
 
1012
889
  elif n == 2 or (n >= 2 and op in {'*', '+'}):
1013
- jax_exprs = [self._jax(arg, info) for arg in args]
1014
- jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
890
+ jax_exprs = [self._jax(arg, init_params) for arg in args]
1015
891
  result = jax_exprs[0]
1016
- for jax_rhs in jax_exprs[1:]:
1017
- result = self._jax_binary(
1018
- result, jax_rhs, jax_op, jax_param, at_least_int=True)
892
+ for i, jax_rhs in enumerate(jax_exprs[1:]):
893
+ jax_op = valid_ops[op](f'{expr.id}_{op}{i}', init_params)
894
+ result = self._jax_binary(result, jax_rhs, jax_op, at_least_int=True)
1019
895
  return result
1020
896
 
1021
897
  JaxRDDLCompiler._check_num_args(expr, 2)
1022
898
 
1023
- def _jax_relational(self, expr, info):
899
+ def _jax_relational(self, expr, init_params):
1024
900
  _, op = expr.etype
1025
901
 
1026
902
  # if expression is non-fluent, always use the exact operation
@@ -1029,17 +905,16 @@ class JaxRDDLCompiler:
1029
905
  else:
1030
906
  valid_ops = self.RELATIONAL_OPS
1031
907
  JaxRDDLCompiler._check_valid_op(expr, valid_ops)
1032
- jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
1033
908
 
1034
909
  # recursively compile arguments
1035
910
  JaxRDDLCompiler._check_num_args(expr, 2)
1036
911
  lhs, rhs = expr.args
1037
- jax_lhs = self._jax(lhs, info)
1038
- jax_rhs = self._jax(rhs, info)
1039
- return self._jax_binary(
1040
- jax_lhs, jax_rhs, jax_op, jax_param, at_least_int=True)
912
+ jax_lhs = self._jax(lhs, init_params)
913
+ jax_rhs = self._jax(rhs, init_params)
914
+ jax_op = valid_ops[op](expr.id, init_params)
915
+ return self._jax_binary(jax_lhs, jax_rhs, jax_op, at_least_int=True)
1041
916
 
1042
- def _jax_logical(self, expr, info):
917
+ def _jax_logical(self, expr, init_params):
1043
918
  _, op = expr.etype
1044
919
 
1045
920
  # if expression is non-fluent, always use the exact operation
@@ -1056,22 +931,21 @@ class JaxRDDLCompiler:
1056
931
  n = len(args)
1057
932
  if n == 1 and op == '~':
1058
933
  arg, = args
1059
- jax_expr = self._jax(arg, info)
1060
- jax_op, jax_param = self._unwrap(logical_not_op, expr.id, info)
1061
- return self._jax_unary(jax_expr, jax_op, jax_param, check_dtype=bool)
934
+ jax_expr = self._jax(arg, init_params)
935
+ jax_op = logical_not_op(expr.id, init_params)
936
+ return self._jax_unary(jax_expr, jax_op, check_dtype=bool)
1062
937
 
1063
938
  elif n == 2 or (n >= 2 and op in {'^', '&', '|'}):
1064
- jax_exprs = [self._jax(arg, info) for arg in args]
1065
- jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
939
+ jax_exprs = [self._jax(arg, init_params) for arg in args]
1066
940
  result = jax_exprs[0]
1067
- for jax_rhs in jax_exprs[1:]:
1068
- result = self._jax_binary(
1069
- result, jax_rhs, jax_op, jax_param, check_dtype=bool)
941
+ for i, jax_rhs in enumerate(jax_exprs[1:]):
942
+ jax_op = valid_ops[op](f'{expr.id}_{op}{i}', init_params)
943
+ result = self._jax_binary(result, jax_rhs, jax_op, check_dtype=bool)
1070
944
  return result
1071
945
 
1072
946
  JaxRDDLCompiler._check_num_args(expr, 2)
1073
947
 
1074
- def _jax_aggregation(self, expr, info):
948
+ def _jax_aggregation(self, expr, init_params):
1075
949
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
1076
950
  _, op = expr.etype
1077
951
 
@@ -1081,28 +955,27 @@ class JaxRDDLCompiler:
1081
955
  else:
1082
956
  valid_ops = self.AGGREGATION_OPS
1083
957
  JaxRDDLCompiler._check_valid_op(expr, valid_ops)
1084
- jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
958
+ is_floating = op not in self.AGGREGATION_BOOL
1085
959
 
1086
960
  # recursively compile arguments
1087
- is_floating = op not in self.AGGREGATION_BOOL
1088
961
  * _, arg = expr.args
1089
962
  _, axes = self.traced.cached_sim_info(expr)
1090
- jax_expr = self._jax(arg, info)
963
+ jax_expr = self._jax(arg, init_params)
964
+ jax_op = valid_ops[op](expr.id, init_params)
1091
965
 
1092
966
  def _jax_wrapped_aggregation(x, params, key):
1093
- sample, key, err = jax_expr(x, params, key)
967
+ sample, key, err, params = jax_expr(x, params, key)
1094
968
  if is_floating:
1095
969
  sample = self.ONE * sample
1096
970
  else:
1097
971
  invalid_cast = jnp.logical_not(jnp.can_cast(sample, bool))
1098
972
  err |= (invalid_cast * ERR)
1099
- param = params.get(jax_param, None)
1100
- sample = jax_op(sample, axis=axes, param=param)
1101
- return sample, key, err
973
+ sample, params = jax_op(sample, axis=axes, params=params)
974
+ return sample, key, err, params
1102
975
 
1103
976
  return _jax_wrapped_aggregation
1104
977
 
1105
- def _jax_functional(self, expr, info):
978
+ def _jax_functional(self, expr, init_params):
1106
979
  _, op = expr.etype
1107
980
 
1108
981
  # if expression is non-fluent, always use the exact operation
@@ -1117,39 +990,36 @@ class JaxRDDLCompiler:
1117
990
  if op in unary_ops:
1118
991
  JaxRDDLCompiler._check_num_args(expr, 1)
1119
992
  arg, = expr.args
1120
- jax_expr = self._jax(arg, info)
1121
- jax_op, jax_param = self._unwrap(unary_ops[op], expr.id, info)
1122
- return self._jax_unary(jax_expr, jax_op, jax_param, at_least_int=True)
993
+ jax_expr = self._jax(arg, init_params)
994
+ jax_op = unary_ops[op](expr.id, init_params)
995
+ return self._jax_unary(jax_expr, jax_op, at_least_int=True)
1123
996
 
1124
997
  elif op in binary_ops:
1125
998
  JaxRDDLCompiler._check_num_args(expr, 2)
1126
999
  lhs, rhs = expr.args
1127
- jax_lhs = self._jax(lhs, info)
1128
- jax_rhs = self._jax(rhs, info)
1129
- jax_op, jax_param = self._unwrap(binary_ops[op], expr.id, info)
1130
- return self._jax_binary(
1131
- jax_lhs, jax_rhs, jax_op, jax_param, at_least_int=True)
1000
+ jax_lhs = self._jax(lhs, init_params)
1001
+ jax_rhs = self._jax(rhs, init_params)
1002
+ jax_op = binary_ops[op](expr.id, init_params)
1003
+ return self._jax_binary(jax_lhs, jax_rhs, jax_op, at_least_int=True)
1132
1004
 
1133
1005
  raise RDDLNotImplementedError(
1134
- f'Function {op} is not supported.\n' +
1135
- print_stack_trace(expr))
1006
+ f'Function {op} is not supported.\n' + print_stack_trace(expr))
1136
1007
 
1137
1008
  # ===========================================================================
1138
1009
  # control flow
1139
1010
  # ===========================================================================
1140
1011
 
1141
- def _jax_control(self, expr, info):
1012
+ def _jax_control(self, expr, init_params):
1142
1013
  _, op = expr.etype
1143
1014
  if op == 'if':
1144
- return self._jax_if(expr, info)
1015
+ return self._jax_if(expr, init_params)
1145
1016
  elif op == 'switch':
1146
- return self._jax_switch(expr, info)
1017
+ return self._jax_switch(expr, init_params)
1147
1018
 
1148
1019
  raise RDDLNotImplementedError(
1149
- f'Control operator {op} is not supported.\n' +
1150
- print_stack_trace(expr))
1020
+ f'Control operator {op} is not supported.\n' + print_stack_trace(expr))
1151
1021
 
1152
- def _jax_if(self, expr, info):
1022
+ def _jax_if(self, expr, init_params):
1153
1023
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
1154
1024
  JaxRDDLCompiler._check_num_args(expr, 3)
1155
1025
  pred, if_true, if_false = expr.args
@@ -1159,27 +1029,26 @@ class JaxRDDLCompiler:
1159
1029
  if_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_IF
1160
1030
  else:
1161
1031
  if_op = self.IF_HELPER
1162
- jax_if, jax_param = self._unwrap(if_op, expr.id, info)
1032
+ jax_op = if_op(expr.id, init_params)
1163
1033
 
1164
1034
  # recursively compile arguments
1165
- jax_pred = self._jax(pred, info)
1166
- jax_true = self._jax(if_true, info)
1167
- jax_false = self._jax(if_false, info)
1035
+ jax_pred = self._jax(pred, init_params)
1036
+ jax_true = self._jax(if_true, init_params)
1037
+ jax_false = self._jax(if_false, init_params)
1168
1038
 
1169
1039
  def _jax_wrapped_if_then_else(x, params, key):
1170
- sample1, key, err1 = jax_pred(x, params, key)
1171
- sample2, key, err2 = jax_true(x, params, key)
1172
- sample3, key, err3 = jax_false(x, params, key)
1173
- param = params.get(jax_param, None)
1174
- sample = jax_if(sample1, sample2, sample3, param)
1040
+ sample1, key, err1, params = jax_pred(x, params, key)
1041
+ sample2, key, err2, params = jax_true(x, params, key)
1042
+ sample3, key, err3, params = jax_false(x, params, key)
1043
+ sample, params = jax_op(sample1, sample2, sample3, params)
1175
1044
  err = err1 | err2 | err3
1176
1045
  invalid_cast = jnp.logical_not(jnp.can_cast(sample1, bool))
1177
1046
  err |= (invalid_cast * ERR)
1178
- return sample, key, err
1047
+ return sample, key, err, params
1179
1048
 
1180
1049
  return _jax_wrapped_if_then_else
1181
1050
 
1182
- def _jax_switch(self, expr, info):
1051
+ def _jax_switch(self, expr, init_params):
1183
1052
  pred, *_ = expr.args
1184
1053
 
1185
1054
  # if predicate is non-fluent, always use the exact operation
@@ -1188,34 +1057,33 @@ class JaxRDDLCompiler:
1188
1057
  switch_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
1189
1058
  else:
1190
1059
  switch_op = self.SWITCH_HELPER
1191
- jax_switch, jax_param = self._unwrap(switch_op, expr.id, info)
1060
+ jax_op = switch_op(expr.id, init_params)
1192
1061
 
1193
1062
  # recursively compile predicate
1194
- jax_pred = self._jax(pred, info)
1063
+ jax_pred = self._jax(pred, init_params)
1195
1064
 
1196
1065
  # recursively compile cases
1197
1066
  cases, default = self.traced.cached_sim_info(expr)
1198
- jax_default = None if default is None else self._jax(default, info)
1199
- jax_cases = [(jax_default if _case is None else self._jax(_case, info))
1067
+ jax_default = None if default is None else self._jax(default, init_params)
1068
+ jax_cases = [(jax_default if _case is None else self._jax(_case, init_params))
1200
1069
  for _case in cases]
1201
1070
 
1202
1071
  def _jax_wrapped_switch(x, params, key):
1203
1072
 
1204
1073
  # sample predicate
1205
- sample_pred, key, err = jax_pred(x, params, key)
1074
+ sample_pred, key, err, params = jax_pred(x, params, key)
1206
1075
 
1207
1076
  # sample cases
1208
1077
  sample_cases = [None] * len(jax_cases)
1209
1078
  for (i, jax_case) in enumerate(jax_cases):
1210
- sample_cases[i], key, err_case = jax_case(x, params, key)
1079
+ sample_cases[i], key, err_case, params = jax_case(x, params, key)
1211
1080
  err |= err_case
1212
1081
  sample_cases = jnp.asarray(
1213
1082
  sample_cases, dtype=self._fix_dtype(sample_cases))
1214
1083
 
1215
1084
  # predicate (enum) is an integer - use it to extract from case array
1216
- param = params.get(jax_param, None)
1217
- sample = jax_switch(sample_pred, sample_cases, param)
1218
- return sample, key, err
1085
+ sample, params = jax_op(sample_pred, sample_cases, params)
1086
+ return sample, key, err, params
1219
1087
 
1220
1088
  return _jax_wrapped_switch
1221
1089
 
@@ -1251,169 +1119,169 @@ class JaxRDDLCompiler:
1251
1119
  # Geometric: (implement safe floor)
1252
1120
  # Student: (no reparameterization)
1253
1121
 
1254
- def _jax_random(self, expr, info):
1122
+ def _jax_random(self, expr, init_params):
1255
1123
  _, name = expr.etype
1256
1124
  if name == 'KronDelta':
1257
- return self._jax_kron(expr, info)
1125
+ return self._jax_kron(expr, init_params)
1258
1126
  elif name == 'DiracDelta':
1259
- return self._jax_dirac(expr, info)
1127
+ return self._jax_dirac(expr, init_params)
1260
1128
  elif name == 'Uniform':
1261
- return self._jax_uniform(expr, info)
1129
+ return self._jax_uniform(expr, init_params)
1262
1130
  elif name == 'Bernoulli':
1263
- return self._jax_bernoulli(expr, info)
1131
+ return self._jax_bernoulli(expr, init_params)
1264
1132
  elif name == 'Normal':
1265
- return self._jax_normal(expr, info)
1133
+ return self._jax_normal(expr, init_params)
1266
1134
  elif name == 'Poisson':
1267
- return self._jax_poisson(expr, info)
1135
+ return self._jax_poisson(expr, init_params)
1268
1136
  elif name == 'Exponential':
1269
- return self._jax_exponential(expr, info)
1137
+ return self._jax_exponential(expr, init_params)
1270
1138
  elif name == 'Weibull':
1271
- return self._jax_weibull(expr, info)
1139
+ return self._jax_weibull(expr, init_params)
1272
1140
  elif name == 'Gamma':
1273
- return self._jax_gamma(expr, info)
1141
+ return self._jax_gamma(expr, init_params)
1274
1142
  elif name == 'Binomial':
1275
- return self._jax_binomial(expr, info)
1143
+ return self._jax_binomial(expr, init_params)
1276
1144
  elif name == 'NegativeBinomial':
1277
- return self._jax_negative_binomial(expr, info)
1145
+ return self._jax_negative_binomial(expr, init_params)
1278
1146
  elif name == 'Beta':
1279
- return self._jax_beta(expr, info)
1147
+ return self._jax_beta(expr, init_params)
1280
1148
  elif name == 'Geometric':
1281
- return self._jax_geometric(expr, info)
1149
+ return self._jax_geometric(expr, init_params)
1282
1150
  elif name == 'Pareto':
1283
- return self._jax_pareto(expr, info)
1151
+ return self._jax_pareto(expr, init_params)
1284
1152
  elif name == 'Student':
1285
- return self._jax_student(expr, info)
1153
+ return self._jax_student(expr, init_params)
1286
1154
  elif name == 'Gumbel':
1287
- return self._jax_gumbel(expr, info)
1155
+ return self._jax_gumbel(expr, init_params)
1288
1156
  elif name == 'Laplace':
1289
- return self._jax_laplace(expr, info)
1157
+ return self._jax_laplace(expr, init_params)
1290
1158
  elif name == 'Cauchy':
1291
- return self._jax_cauchy(expr, info)
1159
+ return self._jax_cauchy(expr, init_params)
1292
1160
  elif name == 'Gompertz':
1293
- return self._jax_gompertz(expr, info)
1161
+ return self._jax_gompertz(expr, init_params)
1294
1162
  elif name == 'ChiSquare':
1295
- return self._jax_chisquare(expr, info)
1163
+ return self._jax_chisquare(expr, init_params)
1296
1164
  elif name == 'Kumaraswamy':
1297
- return self._jax_kumaraswamy(expr, info)
1165
+ return self._jax_kumaraswamy(expr, init_params)
1298
1166
  elif name == 'Discrete':
1299
- return self._jax_discrete(expr, info, unnorm=False)
1167
+ return self._jax_discrete(expr, init_params, unnorm=False)
1300
1168
  elif name == 'UnnormDiscrete':
1301
- return self._jax_discrete(expr, info, unnorm=True)
1169
+ return self._jax_discrete(expr, init_params, unnorm=True)
1302
1170
  elif name == 'Discrete(p)':
1303
- return self._jax_discrete_pvar(expr, info, unnorm=False)
1171
+ return self._jax_discrete_pvar(expr, init_params, unnorm=False)
1304
1172
  elif name == 'UnnormDiscrete(p)':
1305
- return self._jax_discrete_pvar(expr, info, unnorm=True)
1173
+ return self._jax_discrete_pvar(expr, init_params, unnorm=True)
1306
1174
  else:
1307
1175
  raise RDDLNotImplementedError(
1308
1176
  f'Distribution {name} is not supported.\n' +
1309
1177
  print_stack_trace(expr))
1310
1178
 
1311
- def _jax_kron(self, expr, info):
1179
+ def _jax_kron(self, expr, init_params):
1312
1180
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_KRON_DELTA']
1313
1181
  JaxRDDLCompiler._check_num_args(expr, 1)
1314
1182
  arg, = expr.args
1315
- arg = self._jax(arg, info)
1183
+ arg = self._jax(arg, init_params)
1316
1184
 
1317
1185
  # just check that the sample can be cast to int
1318
1186
  def _jax_wrapped_distribution_kron(x, params, key):
1319
- sample, key, err = arg(x, params, key)
1187
+ sample, key, err, params = arg(x, params, key)
1320
1188
  invalid_cast = jnp.logical_not(jnp.can_cast(sample, self.INT))
1321
1189
  err |= (invalid_cast * ERR)
1322
- return sample, key, err
1190
+ return sample, key, err, params
1323
1191
 
1324
1192
  return _jax_wrapped_distribution_kron
1325
1193
 
1326
- def _jax_dirac(self, expr, info):
1194
+ def _jax_dirac(self, expr, init_params):
1327
1195
  JaxRDDLCompiler._check_num_args(expr, 1)
1328
1196
  arg, = expr.args
1329
- arg = self._jax(arg, info, dtype=self.REAL)
1197
+ arg = self._jax(arg, init_params, dtype=self.REAL)
1330
1198
  return arg
1331
1199
 
1332
- def _jax_uniform(self, expr, info):
1200
+ def _jax_uniform(self, expr, init_params):
1333
1201
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_UNIFORM']
1334
1202
  JaxRDDLCompiler._check_num_args(expr, 2)
1335
1203
 
1336
1204
  arg_lb, arg_ub = expr.args
1337
- jax_lb = self._jax(arg_lb, info)
1338
- jax_ub = self._jax(arg_ub, info)
1205
+ jax_lb = self._jax(arg_lb, init_params)
1206
+ jax_ub = self._jax(arg_ub, init_params)
1339
1207
 
1340
1208
  # reparameterization trick U(a, b) = a + (b - a) * U(0, 1)
1341
1209
  def _jax_wrapped_distribution_uniform(x, params, key):
1342
- lb, key, err1 = jax_lb(x, params, key)
1343
- ub, key, err2 = jax_ub(x, params, key)
1210
+ lb, key, err1, params = jax_lb(x, params, key)
1211
+ ub, key, err2, params = jax_ub(x, params, key)
1344
1212
  key, subkey = random.split(key)
1345
1213
  U = random.uniform(key=subkey, shape=jnp.shape(lb), dtype=self.REAL)
1346
1214
  sample = lb + (ub - lb) * U
1347
1215
  out_of_bounds = jnp.logical_not(jnp.all(lb <= ub))
1348
1216
  err = err1 | err2 | (out_of_bounds * ERR)
1349
- return sample, key, err
1217
+ return sample, key, err, params
1350
1218
 
1351
1219
  return _jax_wrapped_distribution_uniform
1352
1220
 
1353
- def _jax_normal(self, expr, info):
1221
+ def _jax_normal(self, expr, init_params):
1354
1222
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_NORMAL']
1355
1223
  JaxRDDLCompiler._check_num_args(expr, 2)
1356
1224
 
1357
1225
  arg_mean, arg_var = expr.args
1358
- jax_mean = self._jax(arg_mean, info)
1359
- jax_var = self._jax(arg_var, info)
1226
+ jax_mean = self._jax(arg_mean, init_params)
1227
+ jax_var = self._jax(arg_var, init_params)
1360
1228
 
1361
1229
  # reparameterization trick N(m, s^2) = m + s * N(0, 1)
1362
1230
  def _jax_wrapped_distribution_normal(x, params, key):
1363
- mean, key, err1 = jax_mean(x, params, key)
1364
- var, key, err2 = jax_var(x, params, key)
1231
+ mean, key, err1, params = jax_mean(x, params, key)
1232
+ var, key, err2, params = jax_var(x, params, key)
1365
1233
  std = jnp.sqrt(var)
1366
1234
  key, subkey = random.split(key)
1367
1235
  Z = random.normal(key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
1368
1236
  sample = mean + std * Z
1369
1237
  out_of_bounds = jnp.logical_not(jnp.all(var >= 0))
1370
1238
  err = err1 | err2 | (out_of_bounds * ERR)
1371
- return sample, key, err
1239
+ return sample, key, err, params
1372
1240
 
1373
1241
  return _jax_wrapped_distribution_normal
1374
1242
 
1375
- def _jax_exponential(self, expr, info):
1243
+ def _jax_exponential(self, expr, init_params):
1376
1244
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_EXPONENTIAL']
1377
1245
  JaxRDDLCompiler._check_num_args(expr, 1)
1378
1246
 
1379
1247
  arg_scale, = expr.args
1380
- jax_scale = self._jax(arg_scale, info)
1248
+ jax_scale = self._jax(arg_scale, init_params)
1381
1249
 
1382
1250
  # reparameterization trick Exp(s) = s * Exp(1)
1383
1251
  def _jax_wrapped_distribution_exp(x, params, key):
1384
- scale, key, err = jax_scale(x, params, key)
1252
+ scale, key, err, params = jax_scale(x, params, key)
1385
1253
  key, subkey = random.split(key)
1386
1254
  Exp1 = random.exponential(
1387
1255
  key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
1388
1256
  sample = scale * Exp1
1389
1257
  out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
1390
1258
  err |= (out_of_bounds * ERR)
1391
- return sample, key, err
1259
+ return sample, key, err, params
1392
1260
 
1393
1261
  return _jax_wrapped_distribution_exp
1394
1262
 
1395
- def _jax_weibull(self, expr, info):
1263
+ def _jax_weibull(self, expr, init_params):
1396
1264
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_WEIBULL']
1397
1265
  JaxRDDLCompiler._check_num_args(expr, 2)
1398
1266
 
1399
1267
  arg_shape, arg_scale = expr.args
1400
- jax_shape = self._jax(arg_shape, info)
1401
- jax_scale = self._jax(arg_scale, info)
1268
+ jax_shape = self._jax(arg_shape, init_params)
1269
+ jax_scale = self._jax(arg_scale, init_params)
1402
1270
 
1403
1271
  # reparameterization trick W(s, r) = r * (-ln(1 - U(0, 1))) ** (1 / s)
1404
1272
  def _jax_wrapped_distribution_weibull(x, params, key):
1405
- shape, key, err1 = jax_shape(x, params, key)
1406
- scale, key, err2 = jax_scale(x, params, key)
1273
+ shape, key, err1, params = jax_shape(x, params, key)
1274
+ scale, key, err2, params = jax_scale(x, params, key)
1407
1275
  key, subkey = random.split(key)
1408
1276
  U = random.uniform(key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
1409
1277
  sample = scale * jnp.power(-jnp.log(U), 1.0 / shape)
1410
1278
  out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
1411
1279
  err = err1 | err2 | (out_of_bounds * ERR)
1412
- return sample, key, err
1280
+ return sample, key, err, params
1413
1281
 
1414
1282
  return _jax_wrapped_distribution_weibull
1415
1283
 
1416
- def _jax_bernoulli(self, expr, info):
1284
+ def _jax_bernoulli(self, expr, init_params):
1417
1285
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_BERNOULLI']
1418
1286
  JaxRDDLCompiler._check_num_args(expr, 1)
1419
1287
  arg_prob, = expr.args
@@ -1423,23 +1291,22 @@ class JaxRDDLCompiler:
1423
1291
  bern_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BERNOULLI
1424
1292
  else:
1425
1293
  bern_op = self.BERNOULLI_HELPER
1426
- jax_bern, jax_param = self._unwrap(bern_op, expr.id, info)
1294
+ jax_op = bern_op(expr.id, init_params)
1427
1295
 
1428
1296
  # recursively compile arguments
1429
- jax_prob = self._jax(arg_prob, info)
1297
+ jax_prob = self._jax(arg_prob, init_params)
1430
1298
 
1431
1299
  def _jax_wrapped_distribution_bernoulli(x, params, key):
1432
- prob, key, err = jax_prob(x, params, key)
1300
+ prob, key, err, params = jax_prob(x, params, key)
1433
1301
  key, subkey = random.split(key)
1434
- param = params.get(jax_param, None)
1435
- sample = jax_bern(subkey, prob, param)
1302
+ sample, params = jax_op(subkey, prob, params)
1436
1303
  out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
1437
1304
  err |= (out_of_bounds * ERR)
1438
- return sample, key, err
1305
+ return sample, key, err, params
1439
1306
 
1440
1307
  return _jax_wrapped_distribution_bernoulli
1441
1308
 
1442
- def _jax_poisson(self, expr, info):
1309
+ def _jax_poisson(self, expr, init_params):
1443
1310
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_POISSON']
1444
1311
  JaxRDDLCompiler._check_num_args(expr, 1)
1445
1312
  arg_rate, = expr.args
@@ -1449,57 +1316,57 @@ class JaxRDDLCompiler:
1449
1316
  poisson_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_POISSON
1450
1317
  else:
1451
1318
  poisson_op = self.POISSON_HELPER
1452
- jax_poisson, jax_param = self._unwrap(poisson_op, expr.id, info)
1319
+ jax_op = poisson_op(expr.id, init_params)
1453
1320
 
1454
1321
  # recursively compile arguments
1455
- jax_rate = self._jax(arg_rate, info)
1322
+ jax_rate = self._jax(arg_rate, init_params)
1456
1323
 
1457
1324
  # uses the implicit JAX subroutine
1458
1325
  def _jax_wrapped_distribution_poisson(x, params, key):
1459
- rate, key, err = jax_rate(x, params, key)
1326
+ rate, key, err, params = jax_rate(x, params, key)
1460
1327
  key, subkey = random.split(key)
1461
- param = params.get(jax_param, None)
1462
- sample = jax_poisson(subkey, rate, param).astype(self.INT)
1328
+ sample, params = jax_op(subkey, rate, params)
1329
+ sample = sample.astype(self.INT)
1463
1330
  out_of_bounds = jnp.logical_not(jnp.all(rate >= 0))
1464
1331
  err |= (out_of_bounds * ERR)
1465
- return sample, key, err
1332
+ return sample, key, err, params
1466
1333
 
1467
1334
  return _jax_wrapped_distribution_poisson
1468
1335
 
1469
- def _jax_gamma(self, expr, info):
1336
+ def _jax_gamma(self, expr, init_params):
1470
1337
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_GAMMA']
1471
1338
  JaxRDDLCompiler._check_num_args(expr, 2)
1472
1339
 
1473
1340
  arg_shape, arg_scale = expr.args
1474
- jax_shape = self._jax(arg_shape, info)
1475
- jax_scale = self._jax(arg_scale, info)
1341
+ jax_shape = self._jax(arg_shape, init_params)
1342
+ jax_scale = self._jax(arg_scale, init_params)
1476
1343
 
1477
1344
  # partial reparameterization trick Gamma(s, r) = r * Gamma(s, 1)
1478
1345
  # uses the implicit JAX subroutine for Gamma(s, 1)
1479
1346
  def _jax_wrapped_distribution_gamma(x, params, key):
1480
- shape, key, err1 = jax_shape(x, params, key)
1481
- scale, key, err2 = jax_scale(x, params, key)
1347
+ shape, key, err1, params = jax_shape(x, params, key)
1348
+ scale, key, err2, params = jax_scale(x, params, key)
1482
1349
  key, subkey = random.split(key)
1483
1350
  Gamma = random.gamma(key=subkey, a=shape, dtype=self.REAL)
1484
1351
  sample = scale * Gamma
1485
1352
  out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
1486
1353
  err = err1 | err2 | (out_of_bounds * ERR)
1487
- return sample, key, err
1354
+ return sample, key, err, params
1488
1355
 
1489
1356
  return _jax_wrapped_distribution_gamma
1490
1357
 
1491
- def _jax_binomial(self, expr, info):
1358
+ def _jax_binomial(self, expr, init_params):
1492
1359
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_BINOMIAL']
1493
1360
  JaxRDDLCompiler._check_num_args(expr, 2)
1494
1361
 
1495
1362
  arg_trials, arg_prob = expr.args
1496
- jax_trials = self._jax(arg_trials, info)
1497
- jax_prob = self._jax(arg_prob, info)
1363
+ jax_trials = self._jax(arg_trials, init_params)
1364
+ jax_prob = self._jax(arg_prob, init_params)
1498
1365
 
1499
1366
  # uses the JAX substrate of tensorflow-probability
1500
1367
  def _jax_wrapped_distribution_binomial(x, params, key):
1501
- trials, key, err2 = jax_trials(x, params, key)
1502
- prob, key, err1 = jax_prob(x, params, key)
1368
+ trials, key, err2, params = jax_trials(x, params, key)
1369
+ prob, key, err1, params = jax_prob(x, params, key)
1503
1370
  trials = jnp.asarray(trials, dtype=self.REAL)
1504
1371
  prob = jnp.asarray(prob, dtype=self.REAL)
1505
1372
  key, subkey = random.split(key)
@@ -1508,55 +1375,56 @@ class JaxRDDLCompiler:
1508
1375
  out_of_bounds = jnp.logical_not(jnp.all(
1509
1376
  (prob >= 0) & (prob <= 1) & (trials >= 0)))
1510
1377
  err = err1 | err2 | (out_of_bounds * ERR)
1511
- return sample, key, err
1378
+ return sample, key, err, params
1512
1379
 
1513
1380
  return _jax_wrapped_distribution_binomial
1514
1381
 
1515
- def _jax_negative_binomial(self, expr, info):
1382
+ def _jax_negative_binomial(self, expr, init_params):
1516
1383
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_NEGATIVE_BINOMIAL']
1517
1384
  JaxRDDLCompiler._check_num_args(expr, 2)
1518
1385
 
1519
1386
  arg_trials, arg_prob = expr.args
1520
- jax_trials = self._jax(arg_trials, info)
1521
- jax_prob = self._jax(arg_prob, info)
1387
+ jax_trials = self._jax(arg_trials, init_params)
1388
+ jax_prob = self._jax(arg_prob, init_params)
1522
1389
 
1523
1390
  # uses the JAX substrate of tensorflow-probability
1524
1391
  def _jax_wrapped_distribution_negative_binomial(x, params, key):
1525
- trials, key, err2 = jax_trials(x, params, key)
1526
- prob, key, err1 = jax_prob(x, params, key)
1392
+ trials, key, err2, params = jax_trials(x, params, key)
1393
+ prob, key, err1, params = jax_prob(x, params, key)
1527
1394
  trials = jnp.asarray(trials, dtype=self.REAL)
1528
1395
  prob = jnp.asarray(prob, dtype=self.REAL)
1529
1396
  key, subkey = random.split(key)
1530
1397
  dist = tfp.distributions.NegativeBinomial(total_count=trials, probs=prob)
1531
1398
  sample = dist.sample(seed=subkey).astype(self.INT)
1532
1399
  out_of_bounds = jnp.logical_not(jnp.all(
1533
- (prob >= 0) & (prob <= 1) & (trials > 0)))
1400
+ (prob >= 0) & (prob <= 1) & (trials > 0))
1401
+ )
1534
1402
  err = err1 | err2 | (out_of_bounds * ERR)
1535
- return sample, key, err
1403
+ return sample, key, err, params
1536
1404
 
1537
1405
  return _jax_wrapped_distribution_negative_binomial
1538
1406
 
1539
- def _jax_beta(self, expr, info):
1407
+ def _jax_beta(self, expr, init_params):
1540
1408
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_BETA']
1541
1409
  JaxRDDLCompiler._check_num_args(expr, 2)
1542
1410
 
1543
1411
  arg_shape, arg_rate = expr.args
1544
- jax_shape = self._jax(arg_shape, info)
1545
- jax_rate = self._jax(arg_rate, info)
1412
+ jax_shape = self._jax(arg_shape, init_params)
1413
+ jax_rate = self._jax(arg_rate, init_params)
1546
1414
 
1547
1415
  # uses the implicit JAX subroutine
1548
1416
  def _jax_wrapped_distribution_beta(x, params, key):
1549
- shape, key, err1 = jax_shape(x, params, key)
1550
- rate, key, err2 = jax_rate(x, params, key)
1417
+ shape, key, err1, params = jax_shape(x, params, key)
1418
+ rate, key, err2, params = jax_rate(x, params, key)
1551
1419
  key, subkey = random.split(key)
1552
1420
  sample = random.beta(key=subkey, a=shape, b=rate, dtype=self.REAL)
1553
1421
  out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (rate > 0)))
1554
1422
  err = err1 | err2 | (out_of_bounds * ERR)
1555
- return sample, key, err
1423
+ return sample, key, err, params
1556
1424
 
1557
1425
  return _jax_wrapped_distribution_beta
1558
1426
 
1559
- def _jax_geometric(self, expr, info):
1427
+ def _jax_geometric(self, expr, init_params):
1560
1428
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_GEOMETRIC']
1561
1429
  JaxRDDLCompiler._check_num_args(expr, 1)
1562
1430
  arg_prob, = expr.args
@@ -1566,187 +1434,187 @@ class JaxRDDLCompiler:
1566
1434
  geom_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_GEOMETRIC
1567
1435
  else:
1568
1436
  geom_op = self.GEOMETRIC_HELPER
1569
- jax_geom, jax_param = self._unwrap(geom_op, expr.id, info)
1437
+ jax_op = geom_op(expr.id, init_params)
1570
1438
 
1571
1439
  # recursively compile arguments
1572
- jax_prob = self._jax(arg_prob, info)
1440
+ jax_prob = self._jax(arg_prob, init_params)
1573
1441
 
1574
1442
  def _jax_wrapped_distribution_geometric(x, params, key):
1575
- prob, key, err = jax_prob(x, params, key)
1443
+ prob, key, err, params = jax_prob(x, params, key)
1576
1444
  key, subkey = random.split(key)
1577
- param = params.get(jax_param, None)
1578
- sample = jax_geom(subkey, prob, param).astype(self.INT)
1445
+ sample, params = jax_op(subkey, prob, params)
1446
+ sample = sample.astype(self.INT)
1579
1447
  out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
1580
1448
  err |= (out_of_bounds * ERR)
1581
- return sample, key, err
1449
+ return sample, key, err, params
1582
1450
 
1583
1451
  return _jax_wrapped_distribution_geometric
1584
1452
 
1585
- def _jax_pareto(self, expr, info):
1453
+ def _jax_pareto(self, expr, init_params):
1586
1454
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_PARETO']
1587
1455
  JaxRDDLCompiler._check_num_args(expr, 2)
1588
1456
 
1589
1457
  arg_shape, arg_scale = expr.args
1590
- jax_shape = self._jax(arg_shape, info)
1591
- jax_scale = self._jax(arg_scale, info)
1458
+ jax_shape = self._jax(arg_shape, init_params)
1459
+ jax_scale = self._jax(arg_scale, init_params)
1592
1460
 
1593
1461
  # partial reparameterization trick Pareto(s, r) = r * Pareto(s, 1)
1594
1462
  # uses the implicit JAX subroutine for Pareto(s, 1)
1595
1463
  def _jax_wrapped_distribution_pareto(x, params, key):
1596
- shape, key, err1 = jax_shape(x, params, key)
1597
- scale, key, err2 = jax_scale(x, params, key)
1464
+ shape, key, err1, params = jax_shape(x, params, key)
1465
+ scale, key, err2, params = jax_scale(x, params, key)
1598
1466
  key, subkey = random.split(key)
1599
1467
  sample = scale * random.pareto(key=subkey, b=shape, dtype=self.REAL)
1600
1468
  out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
1601
1469
  err = err1 | err2 | (out_of_bounds * ERR)
1602
- return sample, key, err
1470
+ return sample, key, err, params
1603
1471
 
1604
1472
  return _jax_wrapped_distribution_pareto
1605
1473
 
1606
- def _jax_student(self, expr, info):
1474
+ def _jax_student(self, expr, init_params):
1607
1475
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_STUDENT']
1608
1476
  JaxRDDLCompiler._check_num_args(expr, 1)
1609
1477
 
1610
1478
  arg_df, = expr.args
1611
- jax_df = self._jax(arg_df, info)
1479
+ jax_df = self._jax(arg_df, init_params)
1612
1480
 
1613
1481
  # uses the implicit JAX subroutine for student(df)
1614
1482
  def _jax_wrapped_distribution_t(x, params, key):
1615
- df, key, err = jax_df(x, params, key)
1483
+ df, key, err, params = jax_df(x, params, key)
1616
1484
  key, subkey = random.split(key)
1617
1485
  sample = random.t(
1618
1486
  key=subkey, df=df, shape=jnp.shape(df), dtype=self.REAL)
1619
1487
  out_of_bounds = jnp.logical_not(jnp.all(df > 0))
1620
1488
  err |= (out_of_bounds * ERR)
1621
- return sample, key, err
1489
+ return sample, key, err, params
1622
1490
 
1623
1491
  return _jax_wrapped_distribution_t
1624
1492
 
1625
- def _jax_gumbel(self, expr, info):
1493
+ def _jax_gumbel(self, expr, init_params):
1626
1494
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_GUMBEL']
1627
1495
  JaxRDDLCompiler._check_num_args(expr, 2)
1628
1496
 
1629
1497
  arg_mean, arg_scale = expr.args
1630
- jax_mean = self._jax(arg_mean, info)
1631
- jax_scale = self._jax(arg_scale, info)
1498
+ jax_mean = self._jax(arg_mean, init_params)
1499
+ jax_scale = self._jax(arg_scale, init_params)
1632
1500
 
1633
1501
  # reparameterization trick Gumbel(m, s) = m + s * Gumbel(0, 1)
1634
1502
  def _jax_wrapped_distribution_gumbel(x, params, key):
1635
- mean, key, err1 = jax_mean(x, params, key)
1636
- scale, key, err2 = jax_scale(x, params, key)
1503
+ mean, key, err1, params = jax_mean(x, params, key)
1504
+ scale, key, err2, params = jax_scale(x, params, key)
1637
1505
  key, subkey = random.split(key)
1638
1506
  Gumbel01 = random.gumbel(
1639
1507
  key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
1640
1508
  sample = mean + scale * Gumbel01
1641
1509
  out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
1642
1510
  err = err1 | err2 | (out_of_bounds * ERR)
1643
- return sample, key, err
1511
+ return sample, key, err, params
1644
1512
 
1645
1513
  return _jax_wrapped_distribution_gumbel
1646
1514
 
1647
- def _jax_laplace(self, expr, info):
1515
+ def _jax_laplace(self, expr, init_params):
1648
1516
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_LAPLACE']
1649
1517
  JaxRDDLCompiler._check_num_args(expr, 2)
1650
1518
 
1651
1519
  arg_mean, arg_scale = expr.args
1652
- jax_mean = self._jax(arg_mean, info)
1653
- jax_scale = self._jax(arg_scale, info)
1520
+ jax_mean = self._jax(arg_mean, init_params)
1521
+ jax_scale = self._jax(arg_scale, init_params)
1654
1522
 
1655
1523
  # reparameterization trick Laplace(m, s) = m + s * Laplace(0, 1)
1656
1524
  def _jax_wrapped_distribution_laplace(x, params, key):
1657
- mean, key, err1 = jax_mean(x, params, key)
1658
- scale, key, err2 = jax_scale(x, params, key)
1525
+ mean, key, err1, params = jax_mean(x, params, key)
1526
+ scale, key, err2, params = jax_scale(x, params, key)
1659
1527
  key, subkey = random.split(key)
1660
1528
  Laplace01 = random.laplace(
1661
1529
  key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
1662
1530
  sample = mean + scale * Laplace01
1663
1531
  out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
1664
1532
  err = err1 | err2 | (out_of_bounds * ERR)
1665
- return sample, key, err
1533
+ return sample, key, err, params
1666
1534
 
1667
1535
  return _jax_wrapped_distribution_laplace
1668
1536
 
1669
- def _jax_cauchy(self, expr, info):
1537
+ def _jax_cauchy(self, expr, init_params):
1670
1538
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_CAUCHY']
1671
1539
  JaxRDDLCompiler._check_num_args(expr, 2)
1672
1540
 
1673
1541
  arg_mean, arg_scale = expr.args
1674
- jax_mean = self._jax(arg_mean, info)
1675
- jax_scale = self._jax(arg_scale, info)
1542
+ jax_mean = self._jax(arg_mean, init_params)
1543
+ jax_scale = self._jax(arg_scale, init_params)
1676
1544
 
1677
1545
  # reparameterization trick Cauchy(m, s) = m + s * Cauchy(0, 1)
1678
1546
  def _jax_wrapped_distribution_cauchy(x, params, key):
1679
- mean, key, err1 = jax_mean(x, params, key)
1680
- scale, key, err2 = jax_scale(x, params, key)
1547
+ mean, key, err1, params = jax_mean(x, params, key)
1548
+ scale, key, err2, params = jax_scale(x, params, key)
1681
1549
  key, subkey = random.split(key)
1682
1550
  Cauchy01 = random.cauchy(
1683
1551
  key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
1684
1552
  sample = mean + scale * Cauchy01
1685
1553
  out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
1686
1554
  err = err1 | err2 | (out_of_bounds * ERR)
1687
- return sample, key, err
1555
+ return sample, key, err, params
1688
1556
 
1689
1557
  return _jax_wrapped_distribution_cauchy
1690
1558
 
1691
- def _jax_gompertz(self, expr, info):
1559
+ def _jax_gompertz(self, expr, init_params):
1692
1560
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_GOMPERTZ']
1693
1561
  JaxRDDLCompiler._check_num_args(expr, 2)
1694
1562
 
1695
1563
  arg_shape, arg_scale = expr.args
1696
- jax_shape = self._jax(arg_shape, info)
1697
- jax_scale = self._jax(arg_scale, info)
1564
+ jax_shape = self._jax(arg_shape, init_params)
1565
+ jax_scale = self._jax(arg_scale, init_params)
1698
1566
 
1699
1567
  # reparameterization trick Gompertz(s, r) = ln(1 - log(U(0, 1)) / s) / r
1700
1568
  def _jax_wrapped_distribution_gompertz(x, params, key):
1701
- shape, key, err1 = jax_shape(x, params, key)
1702
- scale, key, err2 = jax_scale(x, params, key)
1569
+ shape, key, err1, params = jax_shape(x, params, key)
1570
+ scale, key, err2, params = jax_scale(x, params, key)
1703
1571
  key, subkey = random.split(key)
1704
1572
  U = random.uniform(key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
1705
1573
  sample = jnp.log(1.0 - jnp.log(U) / shape) / scale
1706
1574
  out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
1707
1575
  err = err1 | err2 | (out_of_bounds * ERR)
1708
- return sample, key, err
1576
+ return sample, key, err, params
1709
1577
 
1710
1578
  return _jax_wrapped_distribution_gompertz
1711
1579
 
1712
- def _jax_chisquare(self, expr, info):
1580
+ def _jax_chisquare(self, expr, init_params):
1713
1581
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_CHISQUARE']
1714
1582
  JaxRDDLCompiler._check_num_args(expr, 1)
1715
1583
 
1716
1584
  arg_df, = expr.args
1717
- jax_df = self._jax(arg_df, info)
1585
+ jax_df = self._jax(arg_df, init_params)
1718
1586
 
1719
1587
  # use the fact that ChiSquare(df) = Gamma(df/2, 2)
1720
1588
  def _jax_wrapped_distribution_chisquare(x, params, key):
1721
- df, key, err1 = jax_df(x, params, key)
1589
+ df, key, err1, params = jax_df(x, params, key)
1722
1590
  key, subkey = random.split(key)
1723
1591
  shape = df / 2.0
1724
1592
  Gamma = random.gamma(key=subkey, a=shape, dtype=self.REAL)
1725
1593
  sample = 2.0 * Gamma
1726
1594
  out_of_bounds = jnp.logical_not(jnp.all(df > 0))
1727
1595
  err = err1 | (out_of_bounds * ERR)
1728
- return sample, key, err
1596
+ return sample, key, err, params
1729
1597
 
1730
1598
  return _jax_wrapped_distribution_chisquare
1731
1599
 
1732
- def _jax_kumaraswamy(self, expr, info):
1600
+ def _jax_kumaraswamy(self, expr, init_params):
1733
1601
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_KUMARASWAMY']
1734
1602
  JaxRDDLCompiler._check_num_args(expr, 2)
1735
1603
 
1736
1604
  arg_a, arg_b = expr.args
1737
- jax_a = self._jax(arg_a, info)
1738
- jax_b = self._jax(arg_b, info)
1605
+ jax_a = self._jax(arg_a, init_params)
1606
+ jax_b = self._jax(arg_b, init_params)
1739
1607
 
1740
1608
  # uses the reparameterization K(a, b) = (1 - (1 - U(0, 1))^{1/b})^{1/a}
1741
1609
  def _jax_wrapped_distribution_kumaraswamy(x, params, key):
1742
- a, key, err1 = jax_a(x, params, key)
1743
- b, key, err2 = jax_b(x, params, key)
1610
+ a, key, err1, params = jax_a(x, params, key)
1611
+ b, key, err2, params = jax_b(x, params, key)
1744
1612
  key, subkey = random.split(key)
1745
1613
  U = random.uniform(key=subkey, shape=jnp.shape(a), dtype=self.REAL)
1746
1614
  sample = jnp.power(1.0 - jnp.power(U, 1.0 / b), 1.0 / a)
1747
1615
  out_of_bounds = jnp.logical_not(jnp.all((a > 0) & (b > 0)))
1748
1616
  err = err1 | err2 | (out_of_bounds * ERR)
1749
- return sample, key, err
1617
+ return sample, key, err, params
1750
1618
 
1751
1619
  return _jax_wrapped_distribution_kumaraswamy
1752
1620
 
@@ -1754,7 +1622,7 @@ class JaxRDDLCompiler:
1754
1622
  # random variables with enum support
1755
1623
  # ===========================================================================
1756
1624
 
1757
- def _jax_discrete(self, expr, info, unnorm):
1625
+ def _jax_discrete(self, expr, init_params, unnorm):
1758
1626
  NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
1759
1627
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_DISCRETE']
1760
1628
  ordered_args = self.traced.cached_sim_info(expr)
@@ -1766,10 +1634,10 @@ class JaxRDDLCompiler:
1766
1634
  discrete_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
1767
1635
  else:
1768
1636
  discrete_op = self.DISCRETE_HELPER
1769
- jax_discrete, jax_param = self._unwrap(discrete_op, expr.id, info)
1637
+ jax_op = discrete_op(expr.id, init_params)
1770
1638
 
1771
1639
  # compile probability expressions
1772
- jax_probs = [self._jax(arg, info) for arg in ordered_args]
1640
+ jax_probs = [self._jax(arg, init_params) for arg in ordered_args]
1773
1641
 
1774
1642
  def _jax_wrapped_distribution_discrete(x, params, key):
1775
1643
 
@@ -1777,7 +1645,7 @@ class JaxRDDLCompiler:
1777
1645
  error = NORMAL
1778
1646
  prob = [None] * len(jax_probs)
1779
1647
  for (i, jax_prob) in enumerate(jax_probs):
1780
- prob[i], key, error_pdf = jax_prob(x, params, key)
1648
+ prob[i], key, error_pdf, params = jax_prob(x, params, key)
1781
1649
  error |= error_pdf
1782
1650
  prob = jnp.stack(prob, axis=-1)
1783
1651
  if unnorm:
@@ -1786,17 +1654,17 @@ class JaxRDDLCompiler:
1786
1654
 
1787
1655
  # dispatch to sampling subroutine
1788
1656
  key, subkey = random.split(key)
1789
- param = params.get(jax_param, None)
1790
- sample = jax_discrete(subkey, prob, param)
1657
+ sample, params = jax_op(subkey, prob, params)
1791
1658
  out_of_bounds = jnp.logical_not(jnp.logical_and(
1792
1659
  jnp.all(prob >= 0),
1793
- jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
1660
+ jnp.allclose(jnp.sum(prob, axis=-1), 1.0)
1661
+ ))
1794
1662
  error |= (out_of_bounds * ERR)
1795
- return sample, key, error
1663
+ return sample, key, error, params
1796
1664
 
1797
1665
  return _jax_wrapped_distribution_discrete
1798
1666
 
1799
- def _jax_discrete_pvar(self, expr, info, unnorm):
1667
+ def _jax_discrete_pvar(self, expr, init_params, unnorm):
1800
1668
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_DISCRETE']
1801
1669
  JaxRDDLCompiler._check_num_args(expr, 2)
1802
1670
  _, args = expr.args
@@ -1807,28 +1675,28 @@ class JaxRDDLCompiler:
1807
1675
  discrete_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
1808
1676
  else:
1809
1677
  discrete_op = self.DISCRETE_HELPER
1810
- jax_discrete, jax_param = self._unwrap(discrete_op, expr.id, info)
1678
+ jax_op = discrete_op(expr.id, init_params)
1811
1679
 
1812
1680
  # compile probability function
1813
- jax_probs = self._jax(arg, info)
1681
+ jax_probs = self._jax(arg, init_params)
1814
1682
 
1815
1683
  def _jax_wrapped_distribution_discrete_pvar(x, params, key):
1816
1684
 
1817
1685
  # sample probabilities
1818
- prob, key, error = jax_probs(x, params, key)
1686
+ prob, key, error, params = jax_probs(x, params, key)
1819
1687
  if unnorm:
1820
1688
  normalizer = jnp.sum(prob, axis=-1, keepdims=True)
1821
1689
  prob = prob / normalizer
1822
1690
 
1823
1691
  # dispatch to sampling subroutine
1824
1692
  key, subkey = random.split(key)
1825
- param = params.get(jax_param, None)
1826
- sample = jax_discrete(subkey, prob, param)
1693
+ sample, params = jax_op(subkey, prob, params)
1827
1694
  out_of_bounds = jnp.logical_not(jnp.logical_and(
1828
1695
  jnp.all(prob >= 0),
1829
- jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
1696
+ jnp.allclose(jnp.sum(prob, axis=-1), 1.0)
1697
+ ))
1830
1698
  error |= (out_of_bounds * ERR)
1831
- return sample, key, error
1699
+ return sample, key, error, params
1832
1700
 
1833
1701
  return _jax_wrapped_distribution_discrete_pvar
1834
1702
 
@@ -1836,68 +1704,69 @@ class JaxRDDLCompiler:
1836
1704
  # random vectors
1837
1705
  # ===========================================================================
1838
1706
 
1839
- def _jax_random_vector(self, expr, info):
1707
+ def _jax_random_vector(self, expr, init_params):
1840
1708
  _, name = expr.etype
1841
1709
  if name == 'MultivariateNormal':
1842
- return self._jax_multivariate_normal(expr, info)
1710
+ return self._jax_multivariate_normal(expr, init_params)
1843
1711
  elif name == 'MultivariateStudent':
1844
- return self._jax_multivariate_student(expr, info)
1712
+ return self._jax_multivariate_student(expr, init_params)
1845
1713
  elif name == 'Dirichlet':
1846
- return self._jax_dirichlet(expr, info)
1714
+ return self._jax_dirichlet(expr, init_params)
1847
1715
  elif name == 'Multinomial':
1848
- return self._jax_multinomial(expr, info)
1716
+ return self._jax_multinomial(expr, init_params)
1849
1717
  else:
1850
1718
  raise RDDLNotImplementedError(
1851
1719
  f'Distribution {name} is not supported.\n' +
1852
1720
  print_stack_trace(expr))
1853
1721
 
1854
- def _jax_multivariate_normal(self, expr, info):
1722
+ def _jax_multivariate_normal(self, expr, init_params):
1855
1723
  _, args = expr.args
1856
1724
  mean, cov = args
1857
- jax_mean = self._jax(mean, info)
1858
- jax_cov = self._jax(cov, info)
1725
+ jax_mean = self._jax(mean, init_params)
1726
+ jax_cov = self._jax(cov, init_params)
1859
1727
  index, = self.traced.cached_sim_info(expr)
1860
1728
 
1861
1729
  # reparameterization trick MN(m, LL') = LZ + m, where Z ~ Normal(0, 1)
1862
1730
  def _jax_wrapped_distribution_multivariate_normal(x, params, key):
1863
1731
 
1864
1732
  # sample the mean and covariance
1865
- sample_mean, key, err1 = jax_mean(x, params, key)
1866
- sample_cov, key, err2 = jax_cov(x, params, key)
1733
+ sample_mean, key, err1, params = jax_mean(x, params, key)
1734
+ sample_cov, key, err2, params = jax_cov(x, params, key)
1867
1735
 
1868
1736
  # sample Normal(0, 1)
1869
1737
  key, subkey = random.split(key)
1870
1738
  Z = random.normal(
1871
1739
  key=subkey,
1872
1740
  shape=jnp.shape(sample_mean) + (1,),
1873
- dtype=self.REAL)
1741
+ dtype=self.REAL
1742
+ )
1874
1743
 
1875
1744
  # compute L s.t. cov = L * L' and reparameterize
1876
1745
  L = jnp.linalg.cholesky(sample_cov)
1877
1746
  sample = jnp.matmul(L, Z)[..., 0] + sample_mean
1878
1747
  sample = jnp.moveaxis(sample, source=-1, destination=index)
1879
1748
  err = err1 | err2
1880
- return sample, key, err
1749
+ return sample, key, err, params
1881
1750
 
1882
1751
  return _jax_wrapped_distribution_multivariate_normal
1883
1752
 
1884
- def _jax_multivariate_student(self, expr, info):
1753
+ def _jax_multivariate_student(self, expr, init_params):
1885
1754
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_MULTIVARIATE_STUDENT']
1886
1755
 
1887
1756
  _, args = expr.args
1888
1757
  mean, cov, df = args
1889
- jax_mean = self._jax(mean, info)
1890
- jax_cov = self._jax(cov, info)
1891
- jax_df = self._jax(df, info)
1758
+ jax_mean = self._jax(mean, init_params)
1759
+ jax_cov = self._jax(cov, init_params)
1760
+ jax_df = self._jax(df, init_params)
1892
1761
  index, = self.traced.cached_sim_info(expr)
1893
1762
 
1894
1763
  # reparameterization trick MN(m, LL') = LZ + m, where Z ~ StudentT(0, 1)
1895
1764
  def _jax_wrapped_distribution_multivariate_student(x, params, key):
1896
1765
 
1897
1766
  # sample the mean and covariance and degrees of freedom
1898
- sample_mean, key, err1 = jax_mean(x, params, key)
1899
- sample_cov, key, err2 = jax_cov(x, params, key)
1900
- sample_df, key, err3 = jax_df(x, params, key)
1767
+ sample_mean, key, err1, params = jax_mean(x, params, key)
1768
+ sample_cov, key, err2, params = jax_cov(x, params, key)
1769
+ sample_df, key, err3, params = jax_df(x, params, key)
1901
1770
  out_of_bounds = jnp.logical_not(jnp.all(sample_df > 0))
1902
1771
 
1903
1772
  # sample StudentT(0, 1, df) -- broadcast df to same shape as cov
@@ -1908,50 +1777,51 @@ class JaxRDDLCompiler:
1908
1777
  key=subkey,
1909
1778
  df=sample_df,
1910
1779
  shape=jnp.shape(sample_df),
1911
- dtype=self.REAL)
1780
+ dtype=self.REAL
1781
+ )
1912
1782
 
1913
1783
  # compute L s.t. cov = L * L' and reparameterize
1914
1784
  L = jnp.linalg.cholesky(sample_cov)
1915
1785
  sample = jnp.matmul(L, Z)[..., 0] + sample_mean
1916
1786
  sample = jnp.moveaxis(sample, source=-1, destination=index)
1917
1787
  error = err1 | err2 | err3 | (out_of_bounds * ERR)
1918
- return sample, key, error
1788
+ return sample, key, error, params
1919
1789
 
1920
1790
  return _jax_wrapped_distribution_multivariate_student
1921
1791
 
1922
- def _jax_dirichlet(self, expr, info):
1792
+ def _jax_dirichlet(self, expr, init_params):
1923
1793
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_DIRICHLET']
1924
1794
 
1925
1795
  _, args = expr.args
1926
1796
  alpha, = args
1927
- jax_alpha = self._jax(alpha, info)
1797
+ jax_alpha = self._jax(alpha, init_params)
1928
1798
  index, = self.traced.cached_sim_info(expr)
1929
1799
 
1930
1800
  # sample Gamma(alpha_i, 1) and normalize across i
1931
1801
  def _jax_wrapped_distribution_dirichlet(x, params, key):
1932
- alpha, key, error = jax_alpha(x, params, key)
1802
+ alpha, key, error, params = jax_alpha(x, params, key)
1933
1803
  out_of_bounds = jnp.logical_not(jnp.all(alpha > 0))
1934
1804
  error |= (out_of_bounds * ERR)
1935
1805
  key, subkey = random.split(key)
1936
1806
  Gamma = random.gamma(key=subkey, a=alpha, dtype=self.REAL)
1937
1807
  sample = Gamma / jnp.sum(Gamma, axis=-1, keepdims=True)
1938
1808
  sample = jnp.moveaxis(sample, source=-1, destination=index)
1939
- return sample, key, error
1809
+ return sample, key, error, params
1940
1810
 
1941
1811
  return _jax_wrapped_distribution_dirichlet
1942
1812
 
1943
- def _jax_multinomial(self, expr, info):
1813
+ def _jax_multinomial(self, expr, init_params):
1944
1814
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_MULTINOMIAL']
1945
1815
 
1946
1816
  _, args = expr.args
1947
1817
  trials, prob = args
1948
- jax_trials = self._jax(trials, info)
1949
- jax_prob = self._jax(prob, info)
1818
+ jax_trials = self._jax(trials, init_params)
1819
+ jax_prob = self._jax(prob, init_params)
1950
1820
  index, = self.traced.cached_sim_info(expr)
1951
1821
 
1952
1822
  def _jax_wrapped_distribution_multinomial(x, params, key):
1953
- trials, key, err1 = jax_trials(x, params, key)
1954
- prob, key, err2 = jax_prob(x, params, key)
1823
+ trials, key, err1, params = jax_trials(x, params, key)
1824
+ prob, key, err2, params = jax_prob(x, params, key)
1955
1825
  trials = jnp.asarray(trials, dtype=self.REAL)
1956
1826
  prob = jnp.asarray(prob, dtype=self.REAL)
1957
1827
  key, subkey = random.split(key)
@@ -1961,9 +1831,10 @@ class JaxRDDLCompiler:
1961
1831
  out_of_bounds = jnp.logical_not(jnp.all(
1962
1832
  (prob >= 0)
1963
1833
  & jnp.allclose(jnp.sum(prob, axis=-1), 1.0)
1964
- & (trials >= 0)))
1834
+ & (trials >= 0)
1835
+ ))
1965
1836
  error = err1 | err2 | (out_of_bounds * ERR)
1966
- return sample, key, error
1837
+ return sample, key, error, params
1967
1838
 
1968
1839
  return _jax_wrapped_distribution_multinomial
1969
1840
 
@@ -1971,57 +1842,57 @@ class JaxRDDLCompiler:
1971
1842
  # matrix algebra
1972
1843
  # ===========================================================================
1973
1844
 
1974
- def _jax_matrix(self, expr, info):
1845
+ def _jax_matrix(self, expr, init_params):
1975
1846
  _, op = expr.etype
1976
1847
  if op == 'det':
1977
- return self._jax_matrix_det(expr, info)
1848
+ return self._jax_matrix_det(expr, init_params)
1978
1849
  elif op == 'inverse':
1979
- return self._jax_matrix_inv(expr, info, pseudo=False)
1850
+ return self._jax_matrix_inv(expr, init_params, pseudo=False)
1980
1851
  elif op == 'pinverse':
1981
- return self._jax_matrix_inv(expr, info, pseudo=True)
1852
+ return self._jax_matrix_inv(expr, init_params, pseudo=True)
1982
1853
  elif op == 'cholesky':
1983
- return self._jax_matrix_cholesky(expr, info)
1854
+ return self._jax_matrix_cholesky(expr, init_params)
1984
1855
  else:
1985
1856
  raise RDDLNotImplementedError(
1986
1857
  f'Matrix operation {op} is not supported.\n' +
1987
1858
  print_stack_trace(expr))
1988
1859
 
1989
- def _jax_matrix_det(self, expr, info):
1860
+ def _jax_matrix_det(self, expr, init_params):
1990
1861
  * _, arg = expr.args
1991
- jax_arg = self._jax(arg, info)
1862
+ jax_arg = self._jax(arg, init_params)
1992
1863
 
1993
1864
  def _jax_wrapped_matrix_operation_det(x, params, key):
1994
- sample_arg, key, error = jax_arg(x, params, key)
1865
+ sample_arg, key, error, params = jax_arg(x, params, key)
1995
1866
  sample = jnp.linalg.det(sample_arg)
1996
- return sample, key, error
1867
+ return sample, key, error, params
1997
1868
 
1998
1869
  return _jax_wrapped_matrix_operation_det
1999
1870
 
2000
- def _jax_matrix_inv(self, expr, info, pseudo):
1871
+ def _jax_matrix_inv(self, expr, init_params, pseudo):
2001
1872
  _, arg = expr.args
2002
- jax_arg = self._jax(arg, info)
1873
+ jax_arg = self._jax(arg, init_params)
2003
1874
  indices = self.traced.cached_sim_info(expr)
2004
1875
  op = jnp.linalg.pinv if pseudo else jnp.linalg.inv
2005
1876
 
2006
1877
  def _jax_wrapped_matrix_operation_inv(x, params, key):
2007
- sample_arg, key, error = jax_arg(x, params, key)
1878
+ sample_arg, key, error, params = jax_arg(x, params, key)
2008
1879
  sample = op(sample_arg)
2009
1880
  sample = jnp.moveaxis(sample, source=(-2, -1), destination=indices)
2010
- return sample, key, error
1881
+ return sample, key, error, params
2011
1882
 
2012
1883
  return _jax_wrapped_matrix_operation_inv
2013
1884
 
2014
- def _jax_matrix_cholesky(self, expr, info):
1885
+ def _jax_matrix_cholesky(self, expr, init_params):
2015
1886
  _, arg = expr.args
2016
- jax_arg = self._jax(arg, info)
1887
+ jax_arg = self._jax(arg, init_params)
2017
1888
  indices = self.traced.cached_sim_info(expr)
2018
1889
  op = jnp.linalg.cholesky
2019
1890
 
2020
1891
  def _jax_wrapped_matrix_operation_cholesky(x, params, key):
2021
- sample_arg, key, error = jax_arg(x, params, key)
1892
+ sample_arg, key, error, params = jax_arg(x, params, key)
2022
1893
  sample = op(sample_arg)
2023
1894
  sample = jnp.moveaxis(sample, source=(-2, -1), destination=indices)
2024
- return sample, key, error
1895
+ return sample, key, error, params
2025
1896
 
2026
1897
  return _jax_wrapped_matrix_operation_cholesky
2027
1898