pyRDDLGym-jax 2.7__py3-none-any.whl → 3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. pyRDDLGym_jax/__init__.py +1 -1
  2. pyRDDLGym_jax/core/compiler.py +1080 -906
  3. pyRDDLGym_jax/core/logic.py +1537 -1369
  4. pyRDDLGym_jax/core/model.py +75 -86
  5. pyRDDLGym_jax/core/planner.py +883 -935
  6. pyRDDLGym_jax/core/simulator.py +20 -17
  7. pyRDDLGym_jax/core/tuning.py +11 -7
  8. pyRDDLGym_jax/core/visualization.py +115 -78
  9. pyRDDLGym_jax/entry_point.py +2 -1
  10. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +6 -8
  11. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +5 -7
  12. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +7 -8
  13. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +7 -8
  14. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +8 -9
  15. pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +5 -7
  16. pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +5 -7
  17. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +7 -8
  18. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +6 -7
  19. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +6 -7
  20. pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +6 -8
  21. pyRDDLGym_jax/examples/configs/Quadcopter_physics_drp.cfg +17 -0
  22. pyRDDLGym_jax/examples/configs/Quadcopter_physics_slp.cfg +17 -0
  23. pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +5 -7
  24. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +4 -7
  25. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +5 -7
  26. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +4 -7
  27. pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +5 -7
  28. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +6 -7
  29. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +6 -7
  30. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +6 -7
  31. pyRDDLGym_jax/examples/configs/default_drp.cfg +5 -8
  32. pyRDDLGym_jax/examples/configs/default_replan.cfg +5 -8
  33. pyRDDLGym_jax/examples/configs/default_slp.cfg +5 -8
  34. pyRDDLGym_jax/examples/configs/tuning_drp.cfg +6 -8
  35. pyRDDLGym_jax/examples/configs/tuning_replan.cfg +6 -8
  36. pyRDDLGym_jax/examples/configs/tuning_slp.cfg +6 -8
  37. pyRDDLGym_jax/examples/run_plan.py +2 -33
  38. pyRDDLGym_jax/examples/run_tune.py +2 -2
  39. {pyrddlgym_jax-2.7.dist-info → pyrddlgym_jax-3.0.dist-info}/METADATA +22 -23
  40. pyrddlgym_jax-3.0.dist-info/RECORD +51 -0
  41. {pyrddlgym_jax-2.7.dist-info → pyrddlgym_jax-3.0.dist-info}/WHEEL +1 -1
  42. pyRDDLGym_jax/examples/run_gradient.py +0 -102
  43. pyrddlgym_jax-2.7.dist-info/RECORD +0 -50
  44. {pyrddlgym_jax-2.7.dist-info → pyrddlgym_jax-3.0.dist-info}/entry_points.txt +0 -0
  45. {pyrddlgym_jax-2.7.dist-info → pyrddlgym_jax-3.0.dist-info}/licenses/LICENSE +0 -0
  46. {pyrddlgym_jax-2.7.dist-info → pyrddlgym_jax-3.0.dist-info}/top_level.txt +0 -0
@@ -14,12 +14,14 @@
14
14
 
15
15
 
16
16
  from functools import partial
17
+ import termcolor
17
18
  import traceback
18
19
  from typing import Any, Callable, Dict, List, Optional
19
20
 
20
21
  import jax
21
22
  import jax.numpy as jnp
22
23
  import jax.random as random
24
+ import jax.scipy as scipy
23
25
 
24
26
  from pyRDDLGym.core.compiler.initializer import RDDLValueInitializer
25
27
  from pyRDDLGym.core.compiler.levels import RDDLLevelAnalysis
@@ -36,8 +38,6 @@ from pyRDDLGym.core.debug.exception import (
36
38
  from pyRDDLGym.core.debug.logger import Logger
37
39
  from pyRDDLGym.core.simulator import RDDLSimulatorPrecompiled
38
40
 
39
- from pyRDDLGym_jax.core.logic import ExactLogic
40
-
41
41
  # more robust approach - if user does not have this or broken try to continue
42
42
  try:
43
43
  from tensorflow_probability.substrates import jax as tfp
@@ -53,12 +53,11 @@ class JaxRDDLCompiler:
53
53
  All operations are identical to their numpy equivalents.
54
54
  '''
55
55
 
56
- def __init__(self, rddl: RDDLLiftedModel,
56
+ def __init__(self, rddl: RDDLLiftedModel, *args,
57
57
  allow_synchronous_state: bool=True,
58
58
  logger: Optional[Logger]=None,
59
59
  use64bit: bool=False,
60
- compile_non_fluent_exact: bool=True,
61
- python_functions: Optional[Dict[str, Callable]]=None) -> None:
60
+ python_functions: Optional[Dict[str, Callable]]=None, **kwargs) -> None:
62
61
  '''Creates a new RDDL to Jax compiler.
63
62
 
64
63
  :param rddl: the RDDL model to compile into Jax
@@ -66,10 +65,17 @@ class JaxRDDLCompiler:
66
65
  on each other
67
66
  :param logger: to log information about compilation to file
68
67
  :param use64bit: whether to use 64 bit arithmetic
69
- :param compile_non_fluent_exact: whether non-fluent expressions
70
- are always compiled using exact JAX expressions
71
68
  :param python_functions: dictionary of external Python functions to call from RDDL
72
69
  '''
70
+
71
+ # warn about unused parameters
72
+ if args:
73
+ print(termcolor.colored(
74
+ f'[WARN] JaxRDDLCompiler received invalid args {args}.', 'yellow'))
75
+ if kwargs:
76
+ print(termcolor.colored(
77
+ f'[WARN] JaxRDDLCompiler received invalid kwargs {kwargs}.', 'yellow'))
78
+
73
79
  self.rddl = rddl
74
80
  self.logger = logger
75
81
  # jax.config.update('jax_log_compiles', True) # for testing ONLY
@@ -86,7 +92,7 @@ class JaxRDDLCompiler:
86
92
  self.JAX_TYPES = {
87
93
  'int': self.INT,
88
94
  'real': self.REAL,
89
- 'bool': bool
95
+ 'bool': jnp.bool_
90
96
  }
91
97
 
92
98
  # compile initial values
@@ -94,6 +100,7 @@ class JaxRDDLCompiler:
94
100
  self.init_values = initializer.initialize()
95
101
 
96
102
  # compute dependency graph for CPFs and sort them by evaluation order
103
+ self.allow_synchronous_state = allow_synchronous_state
97
104
  sorter = RDDLLevelAnalysis(rddl, allow_synchronous_state=allow_synchronous_state)
98
105
  self.levels = sorter.compute_levels()
99
106
 
@@ -101,10 +108,12 @@ class JaxRDDLCompiler:
101
108
  tracer = RDDLObjectsTracer(rddl, cpf_levels=self.levels)
102
109
  self.traced = tracer.trace()
103
110
 
104
- # extract the box constraints on actions
111
+ # external python functions
105
112
  if python_functions is None:
106
113
  python_functions = {}
107
114
  self.python_functions = python_functions
115
+
116
+ # extract the box constraints on actions
108
117
  simulator = RDDLSimulatorPrecompiled(
109
118
  rddl=self.rddl,
110
119
  init_values=self.init_values,
@@ -112,75 +121,90 @@ class JaxRDDLCompiler:
112
121
  trace_info=self.traced,
113
122
  python_functions=python_functions
114
123
  )
115
- constraints = RDDLConstraints(simulator, vectorized=True)
116
- self.constraints = constraints
117
-
118
- # basic operations - these can be override in subclasses
119
- self.compile_non_fluent_exact = compile_non_fluent_exact
120
- self.AGGREGATION_BOOL = {'forall', 'exists'}
121
- self.EXACT_OPS = ExactLogic(use64bit=self.use64bit).get_operator_dicts()
122
- self.OPS = self.EXACT_OPS
124
+ self.constraints = RDDLConstraints(simulator, vectorized=True)
125
+
126
+ def get_kwargs(self) -> Dict[str, Any]:
127
+ '''Returns a dictionary of configurable parameter name: parameter value pairs.
128
+ '''
129
+ return {
130
+ 'allow_synchronous_state': self.allow_synchronous_state,
131
+ 'use64bit': self.use64bit,
132
+ 'python_functions': self.python_functions
133
+ }
134
+
135
+ def split_fluent_nonfluent(self, values):
136
+ '''Splits given values dictionary into fluent and non-fluent dictionaries.
137
+ '''
138
+ nonfluents = self.rddl.non_fluents
139
+ fls = {name: value for (name, value) in values.items() if name not in nonfluents}
140
+ nfls = {name: value for (name, value) in values.items() if name in nonfluents}
141
+ return fls, nfls
123
142
 
124
143
  # ===========================================================================
125
144
  # main compilation subroutines
126
145
  # ===========================================================================
127
146
 
128
- def compile(self, log_jax_expr: bool=False, heading: str='') -> None:
147
+ def compile(self, log_jax_expr: bool=False,
148
+ heading: str='',
149
+ extra_aux: Dict[str, Any]={}) -> None:
129
150
  '''Compiles the current RDDL into Jax expressions.
130
151
 
131
152
  :param log_jax_expr: whether to pretty-print the compiled Jax functions
132
153
  to the log file
133
154
  :param heading: the heading to print before compilation information
155
+ :param extra_aux: extra info to save during compilations
134
156
  '''
135
- init_params = {}
136
- self.invariants = self._compile_constraints(self.rddl.invariants, init_params)
137
- self.preconditions = self._compile_constraints(self.rddl.preconditions, init_params)
138
- self.terminations = self._compile_constraints(self.rddl.terminations, init_params)
139
- self.cpfs = self._compile_cpfs(init_params)
140
- self.reward = self._compile_reward(init_params)
141
- self.model_params = init_params
142
-
157
+ self.model_aux = {'params': {}, 'overriden': {}}
158
+ self.model_aux.update(extra_aux)
159
+
160
+ self.invariants = self._compile_constraints(self.rddl.invariants, self.model_aux)
161
+ self.preconditions = self._compile_constraints(self.rddl.preconditions, self.model_aux)
162
+ self.terminations = self._compile_constraints(self.rddl.terminations, self.model_aux)
163
+ self.cpfs = self._compile_cpfs(self.model_aux)
164
+ self.reward = self._compile_reward(self.model_aux)
165
+
166
+ # add compiled jax expression to logger
143
167
  if log_jax_expr and self.logger is not None:
144
- printed = self.print_jax()
145
- printed_cpfs = '\n\n'.join(f'{k}: {v}'
146
- for (k, v) in printed['cpfs'].items())
147
- printed_reward = printed['reward']
148
- printed_invariants = '\n\n'.join(v for v in printed['invariants'])
149
- printed_preconds = '\n\n'.join(v for v in printed['preconditions'])
150
- printed_terminals = '\n\n'.join(v for v in printed['terminations'])
151
- printed_params = '\n'.join(f'{k}: {v}' for (k, v) in init_params.items())
152
- message = (
153
- f'[info] {heading}\n'
154
- f'[info] compiled JAX CPFs:\n\n'
155
- f'{printed_cpfs}\n\n'
156
- f'[info] compiled JAX reward:\n\n'
157
- f'{printed_reward}\n\n'
158
- f'[info] compiled JAX invariants:\n\n'
159
- f'{printed_invariants}\n\n'
160
- f'[info] compiled JAX preconditions:\n\n'
161
- f'{printed_preconds}\n\n'
162
- f'[info] compiled JAX terminations:\n\n'
163
- f'{printed_terminals}\n'
164
- f'[info] model parameters:\n'
165
- f'{printed_params}\n'
166
- )
167
- self.logger.log(message)
168
-
169
- def _compile_constraints(self, constraints, init_params):
170
- return [self._jax(expr, init_params, dtype=bool) for expr in constraints]
171
-
172
- def _compile_cpfs(self, init_params):
168
+ self._log_printed_jax(heading)
169
+
170
+ def _log_printed_jax(self, heading=''):
171
+ printed = self.print_jax()
172
+ printed_cpfs = '\n\n'.join(f'{k}: {v}' for (k, v) in printed['cpfs'].items())
173
+ printed_reward = printed['reward']
174
+ printed_invariants = '\n\n'.join(v for v in printed['invariants'])
175
+ printed_preconds = '\n\n'.join(v for v in printed['preconditions'])
176
+ printed_terminals = '\n\n'.join(v for v in printed['terminations'])
177
+ printed_params = '\n'.join(f'{k}: {v}' for (k, v) in self.model_aux['params'].items())
178
+ self.logger.log(
179
+ f'[info] {heading}\n'
180
+ f'[info] compiled JAX CPFs:\n\n'
181
+ f'{printed_cpfs}\n\n'
182
+ f'[info] compiled JAX reward:\n\n'
183
+ f'{printed_reward}\n\n'
184
+ f'[info] compiled JAX invariants:\n\n'
185
+ f'{printed_invariants}\n\n'
186
+ f'[info] compiled JAX preconditions:\n\n'
187
+ f'{printed_preconds}\n\n'
188
+ f'[info] compiled JAX terminations:\n\n'
189
+ f'{printed_terminals}\n'
190
+ f'[info] model parameters:\n'
191
+ f'{printed_params}\n'
192
+ )
193
+
194
+ def _compile_constraints(self, constraints, aux):
195
+ return [self._jax(expr, aux, dtype=jnp.bool_) for expr in constraints]
196
+
197
+ def _compile_cpfs(self, aux):
173
198
  jax_cpfs = {}
174
199
  for cpfs in self.levels.values():
175
200
  for cpf in cpfs:
176
201
  _, expr = self.rddl.cpfs[cpf]
177
- prange = self.rddl.variable_ranges[cpf]
178
- dtype = self.JAX_TYPES.get(prange, self.INT)
179
- jax_cpfs[cpf] = self._jax(expr, init_params, dtype=dtype)
202
+ dtype = self.JAX_TYPES.get(self.rddl.variable_ranges[cpf], self.INT)
203
+ jax_cpfs[cpf] = self._jax(expr, aux, dtype=dtype)
180
204
  return jax_cpfs
181
205
 
182
- def _compile_reward(self, init_params):
183
- return self._jax(self.rddl.reward, init_params, dtype=self.REAL)
206
+ def _compile_reward(self, aux):
207
+ return self._jax(self.rddl.reward, aux, dtype=self.REAL)
184
208
 
185
209
  def _extract_inequality_constraint(self, expr):
186
210
  result = []
@@ -208,55 +232,110 @@ class JaxRDDLCompiler:
208
232
  result.extend(self._extract_equality_constraint(arg))
209
233
  return result
210
234
 
211
- def _jax_nonlinear_constraints(self, init_params):
212
- rddl = self.rddl
213
-
214
- # extract the non-box inequality constraints on actions
215
- inequalities = [constr
216
- for (i, expr) in enumerate(rddl.preconditions)
217
- for constr in self._extract_inequality_constraint(expr)
218
- if not self.constraints.is_box_preconditions[i]]
219
-
220
- # compile them to JAX and write as h(s, a) <= 0
221
- jax_op = ExactLogic.exact_binary_function(jnp.subtract)
222
- jax_inequalities = []
223
- for (left, right) in inequalities:
224
- jax_lhs = self._jax(left, init_params)
225
- jax_rhs = self._jax(right, init_params)
226
- jax_constr = self._jax_binary(jax_lhs, jax_rhs, jax_op, at_least_int=True)
227
- jax_inequalities.append(jax_constr)
228
-
229
- # extract the non-box equality constraints on actions
230
- equalities = [constr
231
- for (i, expr) in enumerate(rddl.preconditions)
232
- for constr in self._extract_equality_constraint(expr)
233
- if not self.constraints.is_box_preconditions[i]]
234
-
235
- # compile them to JAX and write as g(s, a) == 0
236
- jax_equalities = []
237
- for (left, right) in equalities:
238
- jax_lhs = self._jax(left, init_params)
239
- jax_rhs = self._jax(right, init_params)
240
- jax_constr = self._jax_binary(jax_lhs, jax_rhs, jax_op, at_least_int=True)
241
- jax_equalities.append(jax_constr)
242
-
235
+ def _jax_nonlinear_constraints(self, aux):
236
+ jax_equalities, jax_inequalities = [], []
237
+ for (i, expr) in enumerate(self.rddl.preconditions):
238
+ if not self.constraints.is_box_preconditions[i]:
239
+
240
+ # compile inequalities to JAX and write as h(s, a) <= 0
241
+ for (left, right) in self._extract_inequality_constraint(expr):
242
+ jax_lhs = self._jax(left, aux)
243
+ jax_rhs = self._jax(right, aux)
244
+ jax_constr = self._jax_binary(
245
+ jax_lhs, jax_rhs, jnp.subtract, at_least_int=True)
246
+ jax_inequalities.append(jax_constr)
247
+
248
+ # compile equalities to JAX and write as g(s, a) == 0
249
+ for (left, right) in self._extract_equality_constraint(expr):
250
+ jax_lhs = self._jax(left, aux)
251
+ jax_rhs = self._jax(right, aux)
252
+ jax_constr = self._jax_binary(
253
+ jax_lhs, jax_rhs, jnp.subtract, at_least_int=True)
254
+ jax_equalities.append(jax_constr)
243
255
  return jax_inequalities, jax_equalities
244
256
 
257
+ def _jax_preconditions(self):
258
+ preconds = self.preconditions
259
+ def _jax_wrapped_preconditions(key, errors, fls, nfls, params):
260
+ precond_check = jnp.array(True, dtype=jnp.bool_)
261
+ for precond in preconds:
262
+ sample, key, err, params = precond(fls, nfls, params, key)
263
+ precond_check = jnp.logical_and(precond_check, sample)
264
+ errors = errors | err
265
+ return precond_check, key, errors, params
266
+ return _jax_wrapped_preconditions
267
+
268
+ def _jax_inequalities(self, aux_constr):
269
+ inequality_fns, equality_fns = self._jax_nonlinear_constraints(aux_constr)
270
+ def _jax_wrapped_inequalities(key, errors, fls, nfls, params):
271
+ inequalities, equalities = [], []
272
+ for constraint in inequality_fns:
273
+ sample, key, err, params = constraint(fls, nfls, params, key)
274
+ inequalities.append(sample)
275
+ errors = errors | err
276
+ for constraint in equality_fns:
277
+ sample, key, err, params = constraint(fls, nfls, params, key)
278
+ equalities.append(sample)
279
+ errors = errors | err
280
+ return (inequalities, equalities), key, errors, params
281
+ return _jax_wrapped_inequalities
282
+
283
+ def _jax_cpfs(self):
284
+ cpfs = self.cpfs
285
+ def _jax_wrapped_cpfs(key, errors, fls, nfls, params):
286
+ fls = fls.copy()
287
+ for (name, cpf) in cpfs.items():
288
+ fls[name], key, err, params = cpf(fls, nfls, params, key)
289
+ errors = errors | err
290
+ return fls, key, errors, params
291
+ return _jax_wrapped_cpfs
292
+
293
+ def _jax_reward(self):
294
+ reward_fn = self.reward
295
+ def _jax_wrapped_reward(key, errors, fls, nfls, params):
296
+ reward, key, err, params = reward_fn(fls, nfls, params, key)
297
+ errors = errors | err
298
+ return reward, key, errors, params
299
+ return _jax_wrapped_reward
300
+
301
+ def _jax_invariants(self):
302
+ invariants = self.invariants
303
+ def _jax_wrapped_invariants(key, errors, fls, nfls, params):
304
+ invariant_check = jnp.array(True, dtype=jnp.bool_)
305
+ for invariant in invariants:
306
+ sample, key, err, params = invariant(fls, nfls, params, key)
307
+ invariant_check = jnp.logical_and(invariant_check, sample)
308
+ errors = errors | err
309
+ return invariant_check, key, errors, params
310
+ return _jax_wrapped_invariants
311
+
312
+ def _jax_terminations(self):
313
+ terminations = self.terminations
314
+ def _jax_wrapped_terminations(key, errors, fls, nfls, params):
315
+ terminated_check = jnp.array(False, dtype=jnp.bool_)
316
+ for terminal in terminations:
317
+ sample, key, err, params = terminal(fls, nfls, params, key)
318
+ terminated_check = jnp.logical_or(terminated_check, sample)
319
+ errors = errors | err
320
+ return terminated_check, key, errors, params
321
+ return _jax_wrapped_terminations
322
+
245
323
  def compile_transition(self, check_constraints: bool=False,
246
324
  constraint_func: bool=False,
247
- init_params_constr: Dict[str, Any]={},
248
- cache_path_info: bool=False) -> Callable:
325
+ cache_path_info: bool=False,
326
+ aux_constr: Dict[str, Any]={}) -> Callable:
249
327
  '''Compiles the current RDDL model into a JAX transition function that
250
328
  samples the next state.
251
329
 
252
330
  The arguments of the returned function is:
253
331
  - key is the PRNG key
254
332
  - actions is the dict of action tensors
255
- - subs is the dict of current pvar value tensors
256
- - model_params is a dict of parameters for the relaxed model.
333
+ - fls is the dict of current fluent pvar tensors
334
+ - nfls is the dict of nonfluent pvar tensors
335
+ - params is a dict of parameters for the relaxed model.
257
336
 
258
337
  The returned value of the function is:
259
- - subs is the returned next epoch fluent values
338
+ - fls is the returned next epoch fluent values
260
339
  - log includes all the auxiliary information about constraints
261
340
  satisfied, errors, etc.
262
341
 
@@ -284,104 +363,118 @@ class JaxRDDLCompiler:
284
363
  in addition to the usual outputs
285
364
  :param cache_path_info: whether to save full path traces as part of the log
286
365
  '''
287
- NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
288
- rddl = self.rddl
289
- reward_fn, cpfs, preconds, invariants, terminals = \
290
- self.reward, self.cpfs, self.preconditions, self.invariants, self.terminations
366
+ NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
291
367
 
292
- # compile constraint information
368
+ # compile all components of the RDDL
369
+ cpf_fn = self._jax_cpfs()
370
+ reward_fn = self._jax_reward()
371
+
372
+ # compile optional constraints
373
+ precond_fn = invariant_fn = terminal_fn = None
374
+ if check_constraints:
375
+ precond_fn = self._jax_preconditions()
376
+ invariant_fn = self._jax_invariants()
377
+ terminal_fn = self._jax_terminations()
378
+
379
+ # compile optional inequalities
380
+ ineq_fn = None
293
381
  if constraint_func:
294
- inequality_fns, equality_fns = self._jax_nonlinear_constraints(
295
- init_params_constr)
296
- else:
297
- inequality_fns, equality_fns = None, None
298
-
382
+ ineq_fn = self._jax_inequalities(aux_constr)
383
+
299
384
  # do a single step update from the RDDL model
300
- def _jax_wrapped_single_step(key, actions, subs, model_params):
385
+ def _jax_wrapped_single_step(key, actions, fls, nfls, params):
301
386
  errors = NORMAL
302
- subs.update(actions)
387
+
388
+ fls = fls.copy()
389
+ fls.update(actions)
303
390
 
304
391
  # check action preconditions
305
- precond_check = True
306
392
  if check_constraints:
307
- for precond in preconds:
308
- sample, key, err, model_params = precond(subs, model_params, key)
309
- precond_check = jnp.logical_and(precond_check, sample)
310
- errors |= err
393
+ precond, key, errors, params = precond_fn(key, errors, fls, nfls, params)
394
+ else:
395
+ precond = jnp.array(True, dtype=jnp.bool_)
311
396
 
312
397
  # compute h(s, a) <= 0 and g(s, a) == 0 constraint functions
313
- inequalities, equalities = [], []
314
398
  if constraint_func:
315
- for constraint in inequality_fns:
316
- sample, key, err, model_params = constraint(subs, model_params, key)
317
- inequalities.append(sample)
318
- errors |= err
319
- for constraint in equality_fns:
320
- sample, key, err, model_params = constraint(subs, model_params, key)
321
- equalities.append(sample)
322
- errors |= err
399
+ (inequalities, equalities), key, errors, params = ineq_fn(
400
+ key, errors, fls, nfls, params)
401
+ else:
402
+ inequalities, equalities = [], []
323
403
 
324
404
  # calculate CPFs in topological order
325
- for (name, cpf) in cpfs.items():
326
- subs[name], key, err, model_params = cpf(subs, model_params, key)
327
- errors |= err
328
-
329
- # calculate the immediate reward
330
- reward, key, err, model_params = reward_fn(subs, model_params, key)
331
- errors |= err
405
+ fls, key, errors, params = cpf_fn(key, errors, fls, nfls, params)
406
+ fluents = fls if cache_path_info else {}
332
407
 
333
- # calculate fluent values
334
- if cache_path_info:
335
- fluents = {name: values for (name, values) in subs.items()
336
- if name not in rddl.non_fluents}
337
- else:
338
- fluents = {}
408
+ # calculate the immediate reward
409
+ reward, key, errors, params = reward_fn(key, errors, fls, nfls, params)
339
410
 
340
411
  # set the next state to the current state
341
- for (state, next_state) in rddl.next_state.items():
342
- subs[state] = subs[next_state]
412
+ for (state, next_state) in self.rddl.next_state.items():
413
+ fls[state] = fls[next_state]
343
414
 
344
- # check the state invariants
345
- invariant_check = True
415
+ # check the state invariants and termination
346
416
  if check_constraints:
347
- for invariant in invariants:
348
- sample, key, err, model_params = invariant(subs, model_params, key)
349
- invariant_check = jnp.logical_and(invariant_check, sample)
350
- errors |= err
351
-
352
- # check the termination (TODO: zero out reward in s if terminated)
353
- terminated_check = False
354
- if check_constraints:
355
- for terminal in terminals:
356
- sample, key, err, model_params = terminal(subs, model_params, key)
357
- terminated_check = jnp.logical_or(terminated_check, sample)
358
- errors |= err
359
-
417
+ invariant, key, errors, params = invariant_fn(key, errors, fls, nfls, params)
418
+ terminated, key, errors, params = terminal_fn(key, errors, fls, nfls, params)
419
+ else:
420
+ invariant = jnp.array(True, dtype=jnp.bool_)
421
+ terminated = jnp.array(False, dtype=jnp.bool_)
422
+
360
423
  # prepare the return value
361
424
  log = {
362
425
  'fluents': fluents,
363
426
  'reward': reward,
364
427
  'error': errors,
365
- 'precondition': precond_check,
366
- 'invariant': invariant_check,
367
- 'termination': terminated_check
368
- }
369
- if constraint_func:
370
- log['inequalities'] = inequalities
371
- log['equalities'] = equalities
372
-
373
- return subs, log, model_params
428
+ 'precondition': precond,
429
+ 'invariant': invariant,
430
+ 'termination': terminated,
431
+ 'inequalities': inequalities,
432
+ 'equalities': equalities
433
+ }
434
+ return fls, log, params
374
435
 
375
436
  return _jax_wrapped_single_step
376
437
 
438
+ def _compile_policy_step(self, policy, transition_fn):
439
+ def _jax_wrapped_policy_step(key, policy_params, hyperparams, step, fls, nfls,
440
+ model_params):
441
+ key, subkey = random.split(key)
442
+ actions = policy(key, policy_params, hyperparams, step, fls)
443
+ return transition_fn(subkey, actions, fls, nfls, model_params)
444
+ return _jax_wrapped_policy_step
445
+
446
+ def _compile_batched_policy_step(self, policy_step_fn, n_batch, model_params_reduction):
447
+ def _jax_wrapped_batched_policy_step(carry, step):
448
+ key, policy_params, hyperparams, fls, nfls, model_params = carry
449
+ keys = random.split(key, num=1 + n_batch)
450
+ key, subkeys = keys[0], keys[1:]
451
+ fls, log, model_params = jax.vmap(
452
+ policy_step_fn, in_axes=(0, None, None, None, 0, None, None)
453
+ )(subkeys, policy_params, hyperparams, step, fls, nfls, model_params)
454
+ model_params = jax.tree_util.tree_map(model_params_reduction, model_params)
455
+ carry = (key, policy_params, hyperparams, fls, nfls, model_params)
456
+ return carry, log
457
+ return _jax_wrapped_batched_policy_step
458
+
459
+ def _compile_unrolled_policy_step(self, batched_policy_step_fn, n_steps):
460
+ def _jax_wrapped_batched_policy_rollout(key, policy_params, hyperparams, fls, nfls,
461
+ model_params):
462
+ start = (key, policy_params, hyperparams, fls, nfls, model_params)
463
+ steps = jnp.arange(n_steps)
464
+ end, log = jax.lax.scan(batched_policy_step_fn, start, steps)
465
+ log = jax.tree_util.tree_map(partial(jnp.swapaxes, axis1=0, axis2=1), log)
466
+ model_params = end[-1]
467
+ return log, model_params
468
+ return _jax_wrapped_batched_policy_rollout
469
+
377
470
  def compile_rollouts(self, policy: Callable,
378
471
  n_steps: int,
379
472
  n_batch: int,
380
473
  check_constraints: bool=False,
381
474
  constraint_func: bool=False,
382
- init_params_constr: Dict[str, Any]={},
475
+ cache_path_info: bool=False,
383
476
  model_params_reduction: Callable=lambda x: x[0],
384
- cache_path_info: bool=False) -> Callable:
477
+ aux_constr: Dict[str, Any]={}) -> Callable:
385
478
  '''Compiles the current RDDL model into a JAX transition function that
386
479
  samples trajectories with a fixed horizon from a policy.
387
480
 
@@ -389,7 +482,8 @@ class JaxRDDLCompiler:
389
482
  - key is the PRNG key (used by a stochastic policy)
390
483
  - policy_params is a pytree of trainable policy weights
391
484
  - hyperparams is a pytree of (optional) fixed policy hyper-parameters
392
- - subs is the dictionary of current fluent tensor values
485
+ - fls is the dictionary of current fluent tensor values
486
+ - nfls is the dictionary of next step fluent tensor value
393
487
  - model_params is a dict of model hyperparameters.
394
488
 
395
489
  The returned value of the returned function is:
@@ -402,7 +496,7 @@ class JaxRDDLCompiler:
402
496
  - params is a pytree of trainable policy weights
403
497
  - hyperparams is a pytree of (optional) fixed policy hyper-parameters
404
498
  - step is the time index of the decision in the current rollout
405
- - states is a dict of tensors for the current observation.
499
+ - fls is a dict of fluent tensors for the current epoch.
406
500
 
407
501
  :param policy: a Jax compiled function for the policy as described above
408
502
  decision epoch, state dict, and an RNG key and returns an action dict
@@ -413,54 +507,16 @@ class JaxRDDLCompiler:
413
507
  returned log and does not raise an exception
414
508
  :param constraint_func: produces the h(s, a) constraint function
415
509
  in addition to the usual outputs
510
+ :param cache_path_info: whether to save full path traces as part of the log
416
511
  :param model_params_reduction: how to aggregate updated model_params across runs
417
512
  in the batch (defaults to selecting the first element's parameters in the batch)
418
- :param cache_path_info: whether to save full path traces as part of the log
419
513
  '''
420
- rddl = self.rddl
421
- jax_step_fn = self.compile_transition(
422
- check_constraints, constraint_func, init_params_constr, cache_path_info)
423
-
424
- # for POMDP only observ-fluents are assumed visible to the policy
425
- if rddl.observ_fluents:
426
- observed_vars = rddl.observ_fluents
427
- else:
428
- observed_vars = rddl.state_fluents
429
-
430
- # evaluate the step from the policy
431
- def _jax_wrapped_single_step_policy(key, policy_params, hyperparams,
432
- step, subs, model_params):
433
- states = {var: values
434
- for (var, values) in subs.items()
435
- if var in observed_vars}
436
- actions = policy(key, policy_params, hyperparams, step, states)
437
- key, subkey = random.split(key)
438
- return jax_step_fn(subkey, actions, subs, model_params)
439
-
440
- # do a batched step update from the policy
441
- def _jax_wrapped_batched_step_policy(carry, step):
442
- key, policy_params, hyperparams, subs, model_params = carry
443
- key, *subkeys = random.split(key, num=1 + n_batch)
444
- keys = jnp.asarray(subkeys)
445
- subs, log, model_params = jax.vmap(
446
- _jax_wrapped_single_step_policy,
447
- in_axes=(0, None, None, None, 0, None)
448
- )(keys, policy_params, hyperparams, step, subs, model_params)
449
- model_params = jax.tree_util.tree_map(model_params_reduction, model_params)
450
- carry = (key, policy_params, hyperparams, subs, model_params)
451
- return carry, log
452
-
453
- # do a batched roll-out from the policy
454
- def _jax_wrapped_batched_rollout(key, policy_params, hyperparams,
455
- subs, model_params):
456
- start = (key, policy_params, hyperparams, subs, model_params)
457
- steps = jnp.arange(n_steps)
458
- end, log = jax.lax.scan(_jax_wrapped_batched_step_policy, start, steps)
459
- log = jax.tree_util.tree_map(partial(jnp.swapaxes, axis1=0, axis2=1), log)
460
- model_params = end[-1]
461
- return log, model_params
462
-
463
- return _jax_wrapped_batched_rollout
514
+ jax_fn = self.compile_transition(
515
+ check_constraints, constraint_func, cache_path_info, aux_constr)
516
+ jax_fn = self._compile_policy_step(policy, jax_fn)
517
+ jax_fn = self._compile_batched_policy_step(jax_fn, n_batch, model_params_reduction)
518
+ jax_fn = self._compile_unrolled_policy_step(jax_fn, n_steps)
519
+ return jax_fn
464
520
 
465
521
  # ===========================================================================
466
522
  # error checks and prints
@@ -470,43 +526,59 @@ class JaxRDDLCompiler:
470
526
  '''Returns a dictionary containing the string representations of all
471
527
  Jax compiled expressions from the RDDL file.
472
528
  '''
473
- subs = self.init_values
474
- init_params = self.model_params
529
+ fls, nfls = self.split_fluent_nonfluent(self.init_values)
530
+ params = self.model_aux['params']
475
531
  key = jax.random.PRNGKey(42)
476
532
  printed = {
477
- 'cpfs': {name: str(jax.make_jaxpr(expr)(subs, init_params, key))
478
- for (name, expr) in self.cpfs.items()},
479
- 'reward': str(jax.make_jaxpr(self.reward)(subs, init_params, key)),
480
- 'invariants': [str(jax.make_jaxpr(expr)(subs, init_params, key))
481
- for expr in self.invariants],
482
- 'preconditions': [str(jax.make_jaxpr(expr)(subs, init_params, key))
483
- for expr in self.preconditions],
484
- 'terminations': [str(jax.make_jaxpr(expr)(subs, init_params, key))
485
- for expr in self.terminations]
533
+ 'cpfs': {
534
+ name: str(jax.make_jaxpr(expr)(fls, nfls, params, key))
535
+ for (name, expr) in self.cpfs.items()
536
+ },
537
+ 'reward': str(jax.make_jaxpr(self.reward)(fls, nfls, params, key)),
538
+ 'invariants': [
539
+ str(jax.make_jaxpr(expr)(fls, nfls, params, key))
540
+ for expr in self.invariants
541
+ ],
542
+ 'preconditions': [
543
+ str(jax.make_jaxpr(expr)(fls, nfls, params, key))
544
+ for expr in self.preconditions
545
+ ],
546
+ 'terminations': [
547
+ str(jax.make_jaxpr(expr)(fls, nfls, params, key))
548
+ for expr in self.terminations
549
+ ]
486
550
  }
487
551
  return printed
488
552
 
489
553
  def model_parameter_info(self) -> Dict[str, Dict[str, Any]]:
490
554
  '''Returns a dictionary of additional information about model parameters.'''
491
555
  result = {}
492
- for (id, value) in self.model_params.items():
493
- expr_id = int(str(id).split('_')[0])
494
- expr = self.traced.lookup(expr_id)
556
+ for (id, value) in self.model_aux['params'].items():
557
+ expr = self.traced.lookup(id)
495
558
  result[id] = {
496
- 'id': expr_id,
559
+ 'id': id,
497
560
  'rddl_op': ' '.join(expr.etype),
498
561
  'init_value': value
499
562
  }
500
563
  return result
501
564
 
565
+ def overriden_ops_info(self) -> Dict[str, Dict[str, List[int]]]:
566
+ '''Returns a dictionary of operations overriden by another class.'''
567
+ result = {}
568
+ for (id, class_) in self.model_aux['overriden'].items():
569
+ expr = self.traced.lookup(id)
570
+ rddl_op = ' '.join(expr.etype)
571
+ result.setdefault(class_, {}).setdefault(rddl_op, []).append(id)
572
+ return result
573
+
502
574
  @staticmethod
503
575
  def _check_valid_op(expr, valid_ops):
504
576
  etype, op = expr.etype
505
577
  if op not in valid_ops:
506
- valid_op_str = ','.join(valid_ops.keys())
578
+ valid_op_str = ','.join(valid_ops)
507
579
  raise RDDLNotImplementedError(
508
580
  f'{etype} operator {op} is not supported: '
509
- f'must be in {valid_op_str}.\n' + print_stack_trace(expr))
581
+ f'must be one of {valid_op_str}.\n' + print_stack_trace(expr))
510
582
 
511
583
  @staticmethod
512
584
  def _check_num_args(expr, required_args):
@@ -516,6 +588,15 @@ class JaxRDDLCompiler:
516
588
  raise RDDLInvalidNumberOfArgumentsError(
517
589
  f'{etype} operator {op} requires {required_args} arguments, '
518
590
  f'got {actual_args}.\n' + print_stack_trace(expr))
591
+
592
+ @staticmethod
593
+ def _check_num_args_min(expr, required_args):
594
+ actual_args = len(expr.args)
595
+ if actual_args < required_args:
596
+ etype, op = expr.etype
597
+ raise RDDLInvalidNumberOfArgumentsError(
598
+ f'{etype} operator {op} requires at least {required_args} arguments, '
599
+ f'got {actual_args}.\n' + print_stack_trace(expr))
519
600
 
520
601
  ERROR_CODES = {
521
602
  'NORMAL': 0,
@@ -580,8 +661,7 @@ class JaxRDDLCompiler:
580
661
  decomposes it into individual error codes.
581
662
  '''
582
663
  binary = reversed(bin(error)[2:])
583
- errors = [i for (i, c) in enumerate(binary) if c == '1']
584
- return errors
664
+ return [i for (i, c) in enumerate(binary) if c == '1']
585
665
 
586
666
  @staticmethod
587
667
  def get_error_messages(error: int) -> List[str]:
@@ -589,63 +669,59 @@ class JaxRDDLCompiler:
589
669
  decomposes it into error strings.
590
670
  '''
591
671
  codes = JaxRDDLCompiler.get_error_codes(error)
592
- messages = [JaxRDDLCompiler.INVERSE_ERROR_CODES[i] for i in codes]
593
- return messages
672
+ return [JaxRDDLCompiler.INVERSE_ERROR_CODES[i] for i in codes]
594
673
 
595
674
  # ===========================================================================
596
675
  # expression compilation
597
676
  # ===========================================================================
598
677
 
599
- def _jax(self, expr, init_params, dtype=None):
678
+ def _jax(self, expr, aux, dtype=None):
600
679
  etype, _ = expr.etype
601
680
  if etype == 'constant':
602
- jax_expr = self._jax_constant(expr, init_params)
681
+ jax_expr = self._jax_constant(expr, aux)
603
682
  elif etype == 'pvar':
604
- jax_expr = self._jax_pvar(expr, init_params)
683
+ jax_expr = self._jax_pvar(expr, aux)
605
684
  elif etype == 'arithmetic':
606
- jax_expr = self._jax_arithmetic(expr, init_params)
685
+ jax_expr = self._jax_arithmetic(expr, aux)
607
686
  elif etype == 'relational':
608
- jax_expr = self._jax_relational(expr, init_params)
687
+ jax_expr = self._jax_relational(expr, aux)
609
688
  elif etype == 'boolean':
610
- jax_expr = self._jax_logical(expr, init_params)
689
+ jax_expr = self._jax_logical(expr, aux)
611
690
  elif etype == 'aggregation':
612
- jax_expr = self._jax_aggregation(expr, init_params)
691
+ jax_expr = self._jax_aggregation(expr, aux)
613
692
  elif etype == 'func':
614
- jax_expr = self._jax_functional(expr, init_params)
693
+ jax_expr = self._jax_function(expr, aux)
615
694
  elif etype == 'pyfunc':
616
- jax_expr = self._jax_pyfunc(expr, init_params)
695
+ jax_expr = self._jax_pyfunc(expr, aux)
617
696
  elif etype == 'control':
618
- jax_expr = self._jax_control(expr, init_params)
697
+ jax_expr = self._jax_control(expr, aux)
619
698
  elif etype == 'randomvar':
620
- jax_expr = self._jax_random(expr, init_params)
699
+ jax_expr = self._jax_random(expr, aux)
621
700
  elif etype == 'randomvector':
622
- jax_expr = self._jax_random_vector(expr, init_params)
701
+ jax_expr = self._jax_random_vector(expr, aux)
623
702
  elif etype == 'matrix':
624
- jax_expr = self._jax_matrix(expr, init_params)
703
+ jax_expr = self._jax_matrix(expr, aux)
625
704
  else:
626
705
  raise RDDLNotImplementedError(
627
- f'Internal error: expression type {expr} is not supported.\n' +
628
- print_stack_trace(expr))
706
+ f'Expression type {expr} is not supported.\n' + print_stack_trace(expr))
629
707
 
630
708
  # force type cast of tensor as required by caller
631
709
  if dtype is not None:
632
710
  jax_expr = self._jax_cast(jax_expr, dtype)
633
-
634
711
  return jax_expr
635
712
 
636
713
  def _jax_cast(self, jax_expr, dtype):
637
714
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
638
-
639
- def _jax_wrapped_cast(x, params, key):
640
- val, key, err, params = jax_expr(x, params, key)
715
+
716
+ def _jax_wrapped_cast(fls, nfls, params, key):
717
+ val, key, err, params = jax_expr(fls, nfls, params, key)
641
718
  sample = jnp.asarray(val, dtype=dtype)
642
719
  invalid_cast = jnp.logical_and(
643
720
  jnp.logical_not(jnp.can_cast(val, dtype)),
644
721
  jnp.any(sample != val)
645
722
  )
646
- err |= (invalid_cast * ERR)
723
+ err = err | (invalid_cast * ERR)
647
724
  return sample, key, err, params
648
-
649
725
  return _jax_wrapped_cast
650
726
 
651
727
  def _fix_dtype(self, value):
@@ -654,34 +730,33 @@ class JaxRDDLCompiler:
654
730
  return self.INT
655
731
  elif jnp.issubdtype(dtype, jnp.floating):
656
732
  return self.REAL
657
- elif jnp.issubdtype(dtype, jnp.bool_) or jnp.issubdtype(dtype, bool):
658
- return bool
733
+ elif jnp.issubdtype(dtype, jnp.bool_):
734
+ return jnp.bool_
659
735
  else:
660
- raise TypeError(f'Invalid type {dtype} of {value}.')
736
+ raise TypeError(f'dtype {dtype} of {value} is not valid.')
661
737
 
662
738
  # ===========================================================================
663
739
  # leaves
664
740
  # ===========================================================================
665
741
 
666
- def _jax_constant(self, expr, init_params):
742
+ def _jax_constant(self, expr, aux):
667
743
  NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
668
744
  cached_value = self.traced.cached_sim_info(expr)
745
+ dtype = self._fix_dtype(cached_value)
669
746
 
670
- def _jax_wrapped_constant(x, params, key):
671
- sample = jnp.asarray(cached_value, dtype=self._fix_dtype(cached_value))
747
+ def _jax_wrapped_constant(fls, nfls, params, key):
748
+ sample = jnp.asarray(cached_value, dtype=dtype)
672
749
  return sample, key, NORMAL, params
673
-
674
750
  return _jax_wrapped_constant
675
751
 
676
- def _jax_pvar_slice(self, _slice):
752
+ def _jax_pvar_slice(self, slice_):
677
753
  NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
678
754
 
679
- def _jax_wrapped_pvar_slice(x, params, key):
680
- return _slice, key, NORMAL, params
681
-
755
+ def _jax_wrapped_pvar_slice(fls, nfls, params, key):
756
+ return slice_, key, NORMAL, params
682
757
  return _jax_wrapped_pvar_slice
683
758
 
684
- def _jax_pvar(self, expr, init_params):
759
+ def _jax_pvar(self, expr, aux):
685
760
  NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
686
761
  var, pvars = expr.args
687
762
  is_value, cached_info = self.traced.cached_sim_info(expr)
@@ -690,21 +765,19 @@ class JaxRDDLCompiler:
690
765
  # boundary case: domain object is converted to canonical integer index
691
766
  if is_value:
692
767
  cached_value = cached_info
693
-
694
- def _jax_wrapped_object(x, params, key):
695
- sample = jnp.asarray(cached_value, dtype=self._fix_dtype(cached_value))
696
- return sample, key, NORMAL, params
768
+ dtype = self._fix_dtype(cached_value)
697
769
 
770
+ def _jax_wrapped_object(fls, nfls, params, key):
771
+ sample = jnp.asarray(cached_value, dtype=dtype)
772
+ return sample, key, NORMAL, params
698
773
  return _jax_wrapped_object
699
774
 
700
775
  # boundary case: no shape information (e.g. scalar pvar)
701
776
  elif cached_info is None:
702
-
703
- def _jax_wrapped_pvar_scalar(x, params, key):
704
- value = x[var]
777
+ def _jax_wrapped_pvar_scalar(fls, nfls, params, key):
778
+ value = fls[var] if var in fls else nfls[var]
705
779
  sample = jnp.asarray(value, dtype=self._fix_dtype(value))
706
780
  return sample, key, NORMAL, params
707
-
708
781
  return _jax_wrapped_pvar_scalar
709
782
 
710
783
  # must slice and/or reshape value tensor to match free variables
@@ -713,34 +786,29 @@ class JaxRDDLCompiler:
713
786
 
714
787
  # compile nested expressions
715
788
  if slices and op_code == RDDLObjectsTracer.NUMPY_OP_CODE.NESTED_SLICE:
789
+ jax_nested_expr = [
790
+ (self._jax(arg, aux) if slice_ is None else self._jax_pvar_slice(slice_))
791
+ for (arg, slice_) in zip(pvars, slices)
792
+ ]
716
793
 
717
- jax_nested_expr = [(self._jax(arg, init_params)
718
- if _slice is None
719
- else self._jax_pvar_slice(_slice))
720
- for (arg, _slice) in zip(pvars, slices)]
721
-
722
- def _jax_wrapped_pvar_tensor_nested(x, params, key):
794
+ def _jax_wrapped_pvar_tensor_nested(fls, nfls, params, key):
723
795
  error = NORMAL
724
- value = x[var]
796
+ value = fls[var] if var in fls else nfls[var]
725
797
  sample = jnp.asarray(value, dtype=self._fix_dtype(value))
726
- new_slices = [None] * len(jax_nested_expr)
727
- for (i, jax_expr) in enumerate(jax_nested_expr):
728
- new_slice, key, err, params = jax_expr(x, params, key)
729
- if not jnp.issubdtype(jnp.result_type(new_slice), jnp.integer):
730
- new_slice = jnp.asarray(new_slice, dtype=self.INT)
731
- new_slices[i] = new_slice
732
- error |= err
733
- new_slices = tuple(new_slices)
734
- sample = sample[new_slices]
798
+ new_slices = []
799
+ for jax_expr in jax_nested_expr:
800
+ new_slice, key, err, params = jax_expr(fls, nfls, params, key)
801
+ new_slice = jnp.asarray(new_slice, dtype=self.INT)
802
+ new_slices.append(new_slice)
803
+ error = error | err
804
+ sample = sample[tuple(new_slices)]
735
805
  return sample, key, error, params
736
-
737
806
  return _jax_wrapped_pvar_tensor_nested
738
807
 
739
808
  # tensor variable but no nesting
740
809
  else:
741
-
742
- def _jax_wrapped_pvar_tensor_non_nested(x, params, key):
743
- value = x[var]
810
+ def _jax_wrapped_pvar_tensor_non_nested(fls, nfls, params, key):
811
+ value = fls[var] if var in fls else nfls[var]
744
812
  sample = jnp.asarray(value, dtype=self._fix_dtype(value))
745
813
  if slices:
746
814
  sample = sample[slices]
@@ -752,190 +820,408 @@ class JaxRDDLCompiler:
752
820
  elif op_code == RDDLObjectsTracer.NUMPY_OP_CODE.TRANSPOSE:
753
821
  sample = jnp.transpose(sample, axes=op_args)
754
822
  return sample, key, NORMAL, params
755
-
756
823
  return _jax_wrapped_pvar_tensor_non_nested
757
824
 
758
825
  # ===========================================================================
759
- # mathematical
826
+ # boilerplate helper functions
760
827
  # ===========================================================================
761
828
 
762
829
  def _jax_unary(self, jax_expr, jax_op, at_least_int=False, check_dtype=None):
763
830
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
764
831
 
765
- def _jax_wrapped_unary_op(x, params, key):
766
- sample, key, err, params = jax_expr(x, params, key)
832
+ def _jax_wrapped_unary_op(fls, nfls, params, key):
833
+ sample, key, err, params = jax_expr(fls, nfls, params, key)
767
834
  if at_least_int:
768
835
  sample = self.ONE * sample
769
- sample, params = jax_op(sample, params)
770
836
  if check_dtype is not None:
771
837
  invalid_cast = jnp.logical_not(jnp.can_cast(sample, check_dtype))
772
- err |= (invalid_cast * ERR)
838
+ err = err | (invalid_cast * ERR)
839
+ sample = jax_op(sample)
773
840
  return sample, key, err, params
774
-
775
841
  return _jax_wrapped_unary_op
776
842
 
777
843
  def _jax_binary(self, jax_lhs, jax_rhs, jax_op, at_least_int=False, check_dtype=None):
778
844
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
779
845
 
780
- def _jax_wrapped_binary_op(x, params, key):
781
- sample1, key, err1, params = jax_lhs(x, params, key)
782
- sample2, key, err2, params = jax_rhs(x, params, key)
846
+ def _jax_wrapped_binary_op(fls, nfls, params, key):
847
+ sample1, key, err1, params = jax_lhs(fls, nfls, params, key)
848
+ sample2, key, err2, params = jax_rhs(fls, nfls, params, key)
783
849
  if at_least_int:
784
850
  sample1 = self.ONE * sample1
785
851
  sample2 = self.ONE * sample2
786
- sample, params = jax_op(sample1, sample2, params)
852
+ sample = jax_op(sample1, sample2)
787
853
  err = err1 | err2
788
854
  if check_dtype is not None:
789
855
  invalid_cast = jnp.logical_not(jnp.logical_and(
790
856
  jnp.can_cast(sample1, check_dtype),
791
857
  jnp.can_cast(sample2, check_dtype))
792
858
  )
793
- err |= (invalid_cast * ERR)
859
+ err = err | (invalid_cast * ERR)
794
860
  return sample, key, err, params
795
-
796
861
  return _jax_wrapped_binary_op
797
862
 
798
- def _jax_arithmetic(self, expr, init_params):
799
- _, op = expr.etype
800
-
801
- # if expression is non-fluent, always use the exact operation
802
- if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
803
- valid_ops = self.EXACT_OPS['arithmetic']
804
- negative_op = self.EXACT_OPS['negative']
805
- else:
806
- valid_ops = self.OPS['arithmetic']
807
- negative_op = self.OPS['negative']
808
- JaxRDDLCompiler._check_valid_op(expr, valid_ops)
809
-
810
- # recursively compile arguments
811
- args = expr.args
812
- n = len(args)
813
- if n == 1 and op == '-':
814
- arg, = args
815
- jax_expr = self._jax(arg, init_params)
816
- jax_op = negative_op(expr.id, init_params)
817
- return self._jax_unary(jax_expr, jax_op, at_least_int=True)
818
-
819
- elif n == 2 or (n >= 2 and op in {'*', '+'}):
820
- jax_exprs = [self._jax(arg, init_params) for arg in args]
821
- result = jax_exprs[0]
822
- for (i, jax_rhs) in enumerate(jax_exprs[1:]):
823
- jax_op = valid_ops[op](f'{expr.id}_{op}{i}', init_params)
824
- result = self._jax_binary(result, jax_rhs, jax_op, at_least_int=True)
825
- return result
826
-
863
+ def _jax_unary_helper(self, expr, aux, jax_op, at_least_int=False, check_dtype=None):
864
+ JaxRDDLCompiler._check_num_args(expr, 1)
865
+ arg, = expr.args
866
+ jax_expr = self._jax(arg, aux)
867
+ return self._jax_unary(
868
+ jax_expr, jax_op, at_least_int=at_least_int, check_dtype=check_dtype)
869
+
870
+ def _jax_binary_helper(self, expr, aux, jax_op, at_least_int=False, check_dtype=None):
827
871
  JaxRDDLCompiler._check_num_args(expr, 2)
872
+ lhs, rhs = expr.args
873
+ jax_lhs = self._jax(lhs, aux)
874
+ jax_rhs = self._jax(rhs, aux)
875
+ return self._jax_binary(
876
+ jax_lhs, jax_rhs, jax_op, at_least_int=at_least_int, check_dtype=check_dtype)
828
877
 
829
- def _jax_relational(self, expr, init_params):
878
+ def _jax_nary_helper(self, expr, aux, jax_op, at_least_int=False, check_dtype=None):
879
+ JaxRDDLCompiler._check_num_args_min(expr, 2)
880
+ args = expr.args
881
+ jax_exprs = [self._jax(arg, aux) for arg in args]
882
+ result = jax_exprs[0]
883
+ for jax_rhs in jax_exprs[1:]:
884
+ result = self._jax_binary(
885
+ result, jax_rhs, jax_op, at_least_int=at_least_int, check_dtype=check_dtype)
886
+ return result
887
+
888
+ # ===========================================================================
889
+ # arithmetic
890
+ # ===========================================================================
891
+
892
+ def _jax_arithmetic(self, expr, aux):
893
+ JaxRDDLCompiler._check_valid_op(expr, {'-', '+', '*', '/'})
830
894
  _, op = expr.etype
831
-
832
- # if expression is non-fluent, always use the exact operation
833
- if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
834
- valid_ops = self.EXACT_OPS['relational']
835
- else:
836
- valid_ops = self.OPS['relational']
837
- JaxRDDLCompiler._check_valid_op(expr, valid_ops)
838
-
839
- # recursively compile arguments
840
- JaxRDDLCompiler._check_num_args(expr, 2)
841
- lhs, rhs = expr.args
842
- jax_lhs = self._jax(lhs, init_params)
843
- jax_rhs = self._jax(rhs, init_params)
844
- jax_op = valid_ops[op](expr.id, init_params)
845
- return self._jax_binary(jax_lhs, jax_rhs, jax_op, at_least_int=True)
846
-
847
- def _jax_logical(self, expr, init_params):
895
+ if op == '-':
896
+ if len(expr.args) == 1:
897
+ return self._jax_negate(expr, aux)
898
+ else:
899
+ return self._jax_subtract(expr, aux)
900
+ elif op == '/':
901
+ return self._jax_divide(expr, aux)
902
+ elif op == '+':
903
+ return self._jax_add(expr, aux)
904
+ elif op == '*':
905
+ return self._jax_multiply(expr, aux)
906
+
907
+ def _jax_negate(self, expr, aux):
908
+ return self._jax_unary_helper(expr, aux, jnp.negative, at_least_int=True)
909
+
910
+ def _jax_add(self, expr, aux):
911
+ return self._jax_nary_helper(expr, aux, jnp.add, at_least_int=True)
912
+
913
+ def _jax_subtract(self, expr, aux):
914
+ return self._jax_binary_helper(expr, aux, jnp.subtract, at_least_int=True)
915
+
916
+ def _jax_multiply(self, expr, aux):
917
+ return self._jax_nary_helper(expr, aux, jnp.multiply, at_least_int=True)
918
+
919
+ def _jax_divide(self, expr, aux):
920
+ return self._jax_binary_helper(expr, aux, jnp.divide, at_least_int=True)
921
+
922
+ # ===========================================================================
923
+ # relational
924
+ # ===========================================================================
925
+
926
+ def _jax_relational(self, expr, aux):
927
+ JaxRDDLCompiler._check_valid_op(expr, {'>=', '<=', '>', '<', '==', '~='})
848
928
  _, op = expr.etype
929
+ if op == '>=':
930
+ return self._jax_greater_equal(expr, aux)
931
+ elif op == '<=':
932
+ return self._jax_less_equal(expr, aux)
933
+ elif op == '>':
934
+ return self._jax_greater(expr, aux)
935
+ elif op == '<':
936
+ return self._jax_less(expr, aux)
937
+ elif op == '==':
938
+ return self._jax_equal(expr, aux)
939
+ elif op == '~=':
940
+ return self._jax_not_equal(expr, aux)
941
+
942
+ def _jax_greater_equal(self, expr, aux):
943
+ return self._jax_binary_helper(expr, aux, jnp.greater_equal, at_least_int=True)
849
944
 
850
- # if expression is non-fluent, always use the exact operation
851
- if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
852
- valid_ops = self.EXACT_OPS['logical']
853
- logical_not_op = self.EXACT_OPS['logical_not']
854
- else:
855
- valid_ops = self.OPS['logical']
856
- logical_not_op = self.OPS['logical_not']
857
- JaxRDDLCompiler._check_valid_op(expr, valid_ops)
858
-
859
- # recursively compile arguments
860
- args = expr.args
861
- n = len(args)
862
- if n == 1 and op == '~':
863
- arg, = args
864
- jax_expr = self._jax(arg, init_params)
865
- jax_op = logical_not_op(expr.id, init_params)
866
- return self._jax_unary(jax_expr, jax_op, check_dtype=bool)
867
-
868
- elif n == 2 or (n >= 2 and op in {'^', '&', '|'}):
869
- jax_exprs = [self._jax(arg, init_params) for arg in args]
870
- result = jax_exprs[0]
871
- for i, jax_rhs in enumerate(jax_exprs[1:]):
872
- jax_op = valid_ops[op](f'{expr.id}_{op}{i}', init_params)
873
- result = self._jax_binary(result, jax_rhs, jax_op, check_dtype=bool)
874
- return result
875
-
876
- JaxRDDLCompiler._check_num_args(expr, 2)
945
+ def _jax_less_equal(self, expr, aux):
946
+ return self._jax_binary_helper(expr, aux, jnp.less_equal, at_least_int=True)
947
+
948
+ def _jax_greater(self, expr, aux):
949
+ return self._jax_binary_helper(expr, aux, jnp.greater, at_least_int=True)
877
950
 
878
- def _jax_aggregation(self, expr, init_params):
879
- ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
951
+ def _jax_less(self, expr, aux):
952
+ return self._jax_binary_helper(expr, aux, jnp.less, at_least_int=True)
953
+
954
+ def _jax_equal(self, expr, aux):
955
+ return self._jax_binary_helper(expr, aux, jnp.equal, at_least_int=True)
956
+
957
+ def _jax_not_equal(self, expr, aux):
958
+ return self._jax_binary_helper(expr, aux, jnp.not_equal, at_least_int=True)
959
+
960
+ # ===========================================================================
961
+ # logical
962
+ # ===========================================================================
963
+
964
+ def _jax_logical(self, expr, aux):
965
+ JaxRDDLCompiler._check_valid_op(expr, {'^', '&', '|', '~', '=>', '<=>'})
880
966
  _, op = expr.etype
881
-
882
- # if expression is non-fluent, always use the exact operation
883
- if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
884
- valid_ops = self.EXACT_OPS['aggregation']
885
- else:
886
- valid_ops = self.OPS['aggregation']
887
- JaxRDDLCompiler._check_valid_op(expr, valid_ops)
888
- is_floating = op not in self.AGGREGATION_BOOL
889
-
890
- # recursively compile arguments
891
- * _, arg = expr.args
892
- _, axes = self.traced.cached_sim_info(expr)
893
- jax_expr = self._jax(arg, init_params)
894
- jax_op = valid_ops[op](expr.id, init_params)
895
-
896
- def _jax_wrapped_aggregation(x, params, key):
897
- sample, key, err, params = jax_expr(x, params, key)
898
- if is_floating:
899
- sample = self.ONE * sample
967
+ if op == '~':
968
+ if len(expr.args) == 1:
969
+ return self._jax_not(expr, aux)
900
970
  else:
901
- invalid_cast = jnp.logical_not(jnp.can_cast(sample, bool))
902
- err |= (invalid_cast * ERR)
903
- sample, params = jax_op(sample, axis=axes, params=params)
904
- return sample, key, err, params
905
-
906
- return _jax_wrapped_aggregation
907
-
908
- def _jax_functional(self, expr, init_params):
971
+ return self._jax_xor(expr, aux)
972
+ elif op == '^' or op == '&':
973
+ return self._jax_and(expr, aux)
974
+ elif op == '|':
975
+ return self._jax_or(expr, aux)
976
+ elif op == '=>':
977
+ return self._jax_implies(expr, aux)
978
+ elif op == '<=>':
979
+ return self._jax_equiv(expr, aux)
980
+
981
+ def _jax_not(self, expr, aux):
982
+ return self._jax_unary_helper(expr, aux, jnp.logical_not, check_dtype=jnp.bool_)
983
+
984
+ def _jax_and(self, expr, aux):
985
+ return self._jax_nary_helper(expr, aux, jnp.logical_and, check_dtype=jnp.bool_)
986
+
987
+ def _jax_or(self, expr, aux):
988
+ return self._jax_nary_helper(expr, aux, jnp.logical_or, check_dtype=jnp.bool_)
989
+
990
+ def _jax_xor(self, expr, aux):
991
+ return self._jax_binary_helper(expr, aux, jnp.logical_xor, check_dtype=jnp.bool_)
992
+
993
+ def _jax_implies(self, expr, aux):
994
+ def implies_op(x, y):
995
+ return jnp.logical_or(jnp.logical_not(x), y)
996
+ return self._jax_binary_helper(expr, aux, implies_op, check_dtype=jnp.bool_)
997
+
998
+ def _jax_equiv(self, expr, aux):
999
+ return self._jax_binary_helper(expr, aux, jnp.equal, check_dtype=jnp.bool_)
1000
+
1001
+ # ===========================================================================
1002
+ # aggregation
1003
+ # ===========================================================================
1004
+
1005
+ def _jax_aggregation(self, expr, aux):
1006
+ JaxRDDLCompiler._check_valid_op(expr, {'sum', 'avg', 'prod', 'minimum', 'maximum',
1007
+ 'forall', 'exists', 'argmin', 'argmax'})
909
1008
  _, op = expr.etype
910
-
911
- # if expression is non-fluent, always use the exact operation
912
- if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
913
- unary_ops = self.EXACT_OPS['unary']
914
- binary_ops = self.EXACT_OPS['binary']
915
- else:
916
- unary_ops = self.OPS['unary']
917
- binary_ops = self.OPS['binary']
918
-
919
- # recursively compile arguments
920
- if op in unary_ops:
921
- JaxRDDLCompiler._check_num_args(expr, 1)
922
- arg, = expr.args
923
- jax_expr = self._jax(arg, init_params)
924
- jax_op = unary_ops[op](expr.id, init_params)
925
- return self._jax_unary(jax_expr, jax_op, at_least_int=True)
926
-
927
- elif op in binary_ops:
928
- JaxRDDLCompiler._check_num_args(expr, 2)
929
- lhs, rhs = expr.args
930
- jax_lhs = self._jax(lhs, init_params)
931
- jax_rhs = self._jax(rhs, init_params)
932
- jax_op = binary_ops[op](expr.id, init_params)
933
- return self._jax_binary(jax_lhs, jax_rhs, jax_op, at_least_int=True)
934
-
935
- raise RDDLNotImplementedError(
936
- f'Function {op} is not supported.\n' + print_stack_trace(expr))
937
-
938
- def _jax_pyfunc(self, expr, init_params):
1009
+ if op == 'sum':
1010
+ return self._jax_sum(expr, aux)
1011
+ elif op == 'avg':
1012
+ return self._jax_avg(expr, aux)
1013
+ elif op == 'prod':
1014
+ return self._jax_prod(expr, aux)
1015
+ elif op == 'minimum':
1016
+ return self._jax_minimum(expr, aux)
1017
+ elif op == 'maximum':
1018
+ return self._jax_maximum(expr, aux)
1019
+ elif op == 'forall':
1020
+ return self._jax_forall(expr, aux)
1021
+ elif op == 'exists':
1022
+ return self._jax_exists(expr, aux)
1023
+ elif op == 'argmin':
1024
+ return self._jax_argmin(expr, aux)
1025
+ elif op == 'argmax':
1026
+ return self._jax_argmax(expr, aux)
1027
+
1028
+ def _jax_aggregation_helper(self, expr, aux, jax_op, is_bool=False):
1029
+ arg = expr.args[-1]
1030
+ _, axes = self.traced.cached_sim_info(expr)
1031
+ jax_expr = self._jax(arg, aux)
1032
+ return self._jax_unary(
1033
+ jax_expr,
1034
+ jax_op=partial(jax_op, axis=axes),
1035
+ at_least_int=not is_bool,
1036
+ check_dtype=jnp.bool_ if is_bool else None
1037
+ )
1038
+
1039
+ def _jax_sum(self, expr, aux):
1040
+ return self._jax_aggregation_helper(expr, aux, jnp.sum)
1041
+
1042
+ def _jax_avg(self, expr, aux):
1043
+ return self._jax_aggregation_helper(expr, aux, jnp.mean)
1044
+
1045
+ def _jax_prod(self, expr, aux):
1046
+ return self._jax_aggregation_helper(expr, aux, jnp.prod)
1047
+
1048
+ def _jax_minimum(self, expr, aux):
1049
+ return self._jax_aggregation_helper(expr, aux, jnp.min)
1050
+
1051
+ def _jax_maximum(self, expr, aux):
1052
+ return self._jax_aggregation_helper(expr, aux, jnp.max)
1053
+
1054
+ def _jax_forall(self, expr, aux):
1055
+ return self._jax_aggregation_helper(expr, aux, jnp.all, is_bool=True)
1056
+
1057
+ def _jax_exists(self, expr, aux):
1058
+ return self._jax_aggregation_helper(expr, aux, jnp.any, is_bool=True)
1059
+
1060
+ def _jax_argmin(self, expr, aux):
1061
+ return self._jax_aggregation_helper(expr, aux, jnp.argmin)
1062
+
1063
+ def _jax_argmax(self, expr, aux):
1064
+ return self._jax_aggregation_helper(expr, aux, jnp.argmax)
1065
+
1066
+ # ===========================================================================
1067
+ # function
1068
+ # ===========================================================================
1069
+
1070
+ def _jax_function(self, expr, aux):
1071
+ JaxRDDLCompiler._check_valid_op(expr, {'abs', 'sgn', 'round', 'floor', 'ceil',
1072
+ 'cos', 'sin', 'tan', 'acos', 'asin', 'atan',
1073
+ 'cosh', 'sinh', 'tanh', 'exp', 'ln', 'sqrt',
1074
+ 'lngamma', 'gamma',
1075
+ 'div', 'mod', 'fmod', 'min', 'max',
1076
+ 'pow', 'log', 'hypot'})
1077
+ _, op = expr.etype
1078
+
1079
+ # unary functions
1080
+ if op == 'abs':
1081
+ return self._jax_abs(expr, aux)
1082
+ elif op == 'sgn':
1083
+ return self._jax_sgn(expr, aux)
1084
+ elif op == 'round':
1085
+ return self._jax_round(expr, aux)
1086
+ elif op == 'floor':
1087
+ return self._jax_floor(expr, aux)
1088
+ elif op == 'ceil':
1089
+ return self._jax_ceil(expr, aux)
1090
+ elif op == 'cos':
1091
+ return self._jax_cos(expr, aux)
1092
+ elif op == 'sin':
1093
+ return self._jax_sin(expr, aux)
1094
+ elif op == 'tan':
1095
+ return self._jax_tan(expr, aux)
1096
+ elif op == 'acos':
1097
+ return self._jax_acos(expr, aux)
1098
+ elif op == 'asin':
1099
+ return self._jax_asin(expr, aux)
1100
+ elif op == 'atan':
1101
+ return self._jax_atan(expr, aux)
1102
+ elif op == 'cosh':
1103
+ return self._jax_cosh(expr, aux)
1104
+ elif op == 'sinh':
1105
+ return self._jax_sinh(expr, aux)
1106
+ elif op == 'tanh':
1107
+ return self._jax_tanh(expr, aux)
1108
+ elif op == 'exp':
1109
+ return self._jax_exp(expr, aux)
1110
+ elif op == 'ln':
1111
+ return self._jax_ln(expr, aux)
1112
+ elif op == 'sqrt':
1113
+ return self._jax_sqrt(expr, aux)
1114
+ elif op == 'lngamma':
1115
+ return self._jax_lngamma(expr, aux)
1116
+ elif op == 'gamma':
1117
+ return self._jax_gamma(expr, aux)
1118
+
1119
+ # binary functions
1120
+ elif op == 'div':
1121
+ return self._jax_div(expr, aux)
1122
+ elif op == 'mod':
1123
+ return self._jax_mod(expr, aux)
1124
+ elif op == 'fmod':
1125
+ return self._jax_fmod(expr, aux)
1126
+ elif op == 'min':
1127
+ return self._jax_min(expr, aux)
1128
+ elif op == 'max':
1129
+ return self._jax_max(expr, aux)
1130
+ elif op == 'pow':
1131
+ return self._jax_pow(expr, aux)
1132
+ elif op == 'log':
1133
+ return self._jax_log(expr, aux)
1134
+ elif op == 'hypot':
1135
+ return self._jax_hypot(expr, aux)
1136
+
1137
+ def _jax_abs(self, expr, aux):
1138
+ return self._jax_unary_helper(expr, aux, jnp.abs, at_least_int=True)
1139
+
1140
+ def _jax_sgn(self, expr, aux):
1141
+ return self._jax_unary_helper(expr, aux, jnp.sign, at_least_int=True)
1142
+
1143
+ def _jax_round(self, expr, aux):
1144
+ return self._jax_unary_helper(expr, aux, jnp.round, at_least_int=True)
1145
+
1146
+ def _jax_floor(self, expr, aux):
1147
+ return self._jax_unary_helper(expr, aux, jnp.floor, at_least_int=True)
1148
+
1149
+ def _jax_ceil(self, expr, aux):
1150
+ return self._jax_unary_helper(expr, aux, jnp.ceil, at_least_int=True)
1151
+
1152
+ def _jax_cos(self, expr, aux):
1153
+ return self._jax_unary_helper(expr, aux, jnp.cos, at_least_int=True)
1154
+
1155
+ def _jax_sin(self, expr, aux):
1156
+ return self._jax_unary_helper(expr, aux, jnp.sin, at_least_int=True)
1157
+
1158
+ def _jax_tan(self, expr, aux):
1159
+ return self._jax_unary_helper(expr, aux, jnp.tan, at_least_int=True)
1160
+
1161
+ def _jax_acos(self, expr, aux):
1162
+ return self._jax_unary_helper(expr, aux, jnp.arccos, at_least_int=True)
1163
+
1164
+ def _jax_asin(self, expr, aux):
1165
+ return self._jax_unary_helper(expr, aux, jnp.arcsin, at_least_int=True)
1166
+
1167
+ def _jax_atan(self, expr, aux):
1168
+ return self._jax_unary_helper(expr, aux, jnp.arctan, at_least_int=True)
1169
+
1170
+ def _jax_cosh(self, expr, aux):
1171
+ return self._jax_unary_helper(expr, aux, jnp.cosh, at_least_int=True)
1172
+
1173
+ def _jax_sinh(self, expr, aux):
1174
+ return self._jax_unary_helper(expr, aux, jnp.sinh, at_least_int=True)
1175
+
1176
+ def _jax_tanh(self, expr, aux):
1177
+ return self._jax_unary_helper(expr, aux, jnp.tanh, at_least_int=True)
1178
+
1179
+ def _jax_exp(self, expr, aux):
1180
+ return self._jax_unary_helper(expr, aux, jnp.exp, at_least_int=True)
1181
+
1182
+ def _jax_ln(self, expr, aux):
1183
+ return self._jax_unary_helper(expr, aux, jnp.ln, at_least_int=True)
1184
+
1185
+ def _jax_sqrt(self, expr, aux):
1186
+ return self._jax_unary_helper(expr, aux, jnp.sqrt, at_least_int=True)
1187
+
1188
+ def _jax_lngamma(self, expr, aux):
1189
+ return self._jax_unary_helper(expr, aux, scipy.special.gammaln, at_least_int=True)
1190
+
1191
+ def _jax_gamma(self, expr, aux):
1192
+ return self._jax_unary_helper(expr, aux, scipy.special.gamma, at_least_int=True)
1193
+
1194
+ def _jax_div(self, expr, aux):
1195
+ return self._jax_binary_helper(expr, aux, jnp.floor_divide, at_least_int=True)
1196
+
1197
+ def _jax_mod(self, expr, aux):
1198
+ return self._jax_binary_helper(expr, aux, jnp.mod, at_least_int=True)
1199
+
1200
+ def _jax_fmod(self, expr, aux):
1201
+ return self._jax_binary_helper(expr, aux, jnp.mod, at_least_int=True)
1202
+
1203
+ def _jax_min(self, expr, aux):
1204
+ return self._jax_binary_helper(expr, aux, jnp.minimum, at_least_int=True)
1205
+
1206
+ def _jax_max(self, expr, aux):
1207
+ return self._jax_binary_helper(expr, aux, jnp.maximum, at_least_int=True)
1208
+
1209
+ def _jax_pow(self, expr, aux):
1210
+ return self._jax_binary_helper(expr, aux, jnp.power, at_least_int=True)
1211
+
1212
+ def _jax_log(self, expr, aux):
1213
+ def log_op(x, y):
1214
+ return jnp.log(x) / jnp.log(y)
1215
+ return self._jax_binary_helper(expr, aux, log_op, at_least_int=True)
1216
+
1217
+ def _jax_hypot(self, expr, aux):
1218
+ return self._jax_binary_helper(expr, aux, jnp.hypot, at_least_int=True)
1219
+
1220
+ # ===========================================================================
1221
+ # external function
1222
+ # ===========================================================================
1223
+
1224
+ def _jax_pyfunc(self, expr, aux):
939
1225
  NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
940
1226
 
941
1227
  # get the Python function by name
@@ -957,25 +1243,21 @@ class JaxRDDLCompiler:
957
1243
  require_dims = self.rddl.object_counts(captured_types)
958
1244
 
959
1245
  # compile the inputs to the function
960
- jax_inputs = [self._jax(arg, init_params) for arg in args]
1246
+ jax_inputs = [self._jax(arg, aux) for arg in args]
961
1247
 
962
1248
  # compile the function evaluation function
963
- def _jax_wrapped_external_function(x, params, key):
1249
+ def _jax_wrapped_external_function(fls, nfls, params, key):
964
1250
 
965
1251
  # evaluate inputs to the function
966
1252
  # first dimensions are non-captured vars in outer scope followed by all the _
967
1253
  error = NORMAL
968
1254
  flat_samples = []
969
1255
  for jax_expr in jax_inputs:
970
- sample, key, err, params = jax_expr(x, params, key)
971
- shape = jnp.shape(sample)
972
- first_dim = 1
973
- for dim in shape[:num_free_vars]:
974
- first_dim *= dim
975
- new_shape = (first_dim,) + shape[num_free_vars:]
1256
+ sample, key, err, params = jax_expr(fls, nfls, params, key)
1257
+ new_shape = (-1,) + jnp.shape(sample)[num_free_vars:]
976
1258
  flat_sample = jnp.reshape(sample, new_shape)
977
1259
  flat_samples.append(flat_sample)
978
- error |= err
1260
+ error = error | err
979
1261
 
980
1262
  # now all the inputs have dimensions equal to (k,) + the number of _ occurences
981
1263
  # k is the number of possible non-captured object combinations
@@ -986,7 +1268,8 @@ class JaxRDDLCompiler:
986
1268
  if not isinstance(sample, jnp.ndarray):
987
1269
  raise ValueError(
988
1270
  f'Output of external Python function <{pyfunc_name}> '
989
- f'is not a JAX array.\n' + print_stack_trace(expr))
1271
+ f'is not a JAX array.\n' + print_stack_trace(expr)
1272
+ )
990
1273
 
991
1274
  pyfunc_dims = jnp.shape(sample)[1:]
992
1275
  if len(require_dims) != len(pyfunc_dims):
@@ -994,14 +1277,16 @@ class JaxRDDLCompiler:
994
1277
  f'External Python function <{pyfunc_name}> returned array with '
995
1278
  f'{len(pyfunc_dims)} dimensions, which does not match the '
996
1279
  f'number of captured parameter(s) {len(require_dims)}.\n' +
997
- print_stack_trace(expr))
1280
+ print_stack_trace(expr)
1281
+ )
998
1282
  for (param, require_dim, actual_dim) in zip(captured_vars, require_dims, pyfunc_dims):
999
1283
  if require_dim != actual_dim:
1000
1284
  raise ValueError(
1001
1285
  f'External Python function <{pyfunc_name}> returned array with '
1002
1286
  f'{actual_dim} elements for captured parameter <{param}>, '
1003
1287
  f'which does not match the number of objects {require_dim}.\n' +
1004
- print_stack_trace(expr))
1288
+ print_stack_trace(expr)
1289
+ )
1005
1290
 
1006
1291
  # unravel the combinations k back into their original dimensions
1007
1292
  sample = jnp.reshape(sample, free_dims + pyfunc_dims)
@@ -1017,111 +1302,75 @@ class JaxRDDLCompiler:
1017
1302
  # control flow
1018
1303
  # ===========================================================================
1019
1304
 
1020
def _jax_control(self, expr, aux):
    """Dispatch compilation of a control-flow expression to its handler.

    Only 'if' and 'switch' are accepted; _check_valid_op raises for
    anything else, so the dispatch below sees a valid op.
    """
    JaxRDDLCompiler._check_valid_op(expr, {'if', 'switch'})
    _, op = expr.etype
    handler = self._jax_if if op == 'if' else self._jax_switch
    return handler(expr, aux)
 
1030
def _jax_if(self, expr, aux):
    """Compile if-then-else: both branches are evaluated, then merged per-element.

    The predicate sample is compared against 0.5 inside jnp.where, and an
    INVALID_CAST error flag is raised when it cannot be cast to bool.
    """
    ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
    JaxRDDLCompiler._check_num_args(expr, 3)
    pred, if_true, if_false = expr.args

    # compile the three sub-expressions up front
    jax_pred = self._jax(pred, aux)
    jax_true = self._jax(if_true, aux)
    jax_false = self._jax(if_false, aux)

    def _jax_wrapped_if_then_else(fls, nfls, params, key):
        # evaluate predicate and both branches, threading the key through all
        pred_sample, key, pred_err, params = jax_pred(fls, nfls, params, key)
        true_sample, key, true_err, params = jax_true(fls, nfls, params, key)
        false_sample, key, false_err, params = jax_false(fls, nfls, params, key)

        # select per-element; predicate > 0.5 counts as true
        sample = jnp.where(pred_sample > 0.5, true_sample, false_sample)

        # merge error flags and record an invalid predicate cast
        bad_cast = jnp.logical_not(jnp.can_cast(pred_sample, jnp.bool_))
        err = pred_err | true_err | false_err | (bad_cast * ERR)
        return sample, key, err, params

    return _jax_wrapped_if_then_else
1333
 
1059
- def _jax_switch(self, expr, init_params):
1060
- pred, *_ = expr.args
1061
-
1062
- # if predicate is non-fluent, always use the exact operation
1063
- # case conditions are currently only literals so they are non-fluent
1064
- if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(pred):
1065
- switch_op = self.EXACT_OPS['control']['switch']
1066
- else:
1067
- switch_op = self.OPS['control']['switch']
1068
- jax_op = switch_op(expr.id, init_params)
1069
-
1334
+ def _jax_switch(self, expr, aux):
1335
+
1070
1336
  # recursively compile predicate
1071
- jax_pred = self._jax(pred, init_params)
1337
+ pred = expr.args[0]
1338
+ jax_pred = self._jax(pred, aux)
1072
1339
 
1073
1340
  # recursively compile cases
1074
1341
  cases, default = self.traced.cached_sim_info(expr)
1075
- jax_default = None if default is None else self._jax(default, init_params)
1076
- jax_cases = [(jax_default if _case is None else self._jax(_case, init_params))
1077
- for _case in cases]
1342
+ jax_default = None if default is None else self._jax(default, aux)
1343
+ jax_cases = [
1344
+ (jax_default if _case is None else self._jax(_case, aux))
1345
+ for _case in cases
1346
+ ]
1078
1347
 
1079
- def _jax_wrapped_switch(x, params, key):
1348
+ def _jax_wrapped_switch(fls, nfls, params, key):
1080
1349
 
1081
1350
  # sample predicate
1082
- sample_pred, key, err, params = jax_pred(x, params, key)
1351
+ sample_pred, key, err, params = jax_pred(fls, nfls, params, key)
1083
1352
 
1084
1353
  # sample cases
1085
- sample_cases = [None] * len(jax_cases)
1086
- for (i, jax_case) in enumerate(jax_cases):
1087
- sample_cases[i], key, err_case, params = jax_case(x, params, key)
1088
- err |= err_case
1354
+ sample_cases = []
1355
+ for jax_case in jax_cases:
1356
+ sample, key, err_case, params = jax_case(fls, nfls, params, key)
1357
+ sample_cases.append(sample)
1358
+ err = err | err_case
1089
1359
  sample_cases = jnp.asarray(sample_cases)
1090
1360
  sample_cases = jnp.asarray(sample_cases, dtype=self._fix_dtype(sample_cases))
1091
1361
 
1092
1362
  # predicate (enum) is an integer - use it to extract from case array
1093
- sample, params = jax_op(sample_pred, sample_cases, params)
1363
+ sample_pred = jnp.asarray(sample_pred[jnp.newaxis, ...], dtype=self.INT)
1364
+ sample = jnp.take_along_axis(sample_cases, sample_pred, axis=0)
1365
+ assert sample.shape[0] == 1
1366
+ sample = sample[0, ...]
1094
1367
  return sample, key, err, params
1095
-
1096
1368
  return _jax_wrapped_switch
1097
1369
 
1098
1370
  # ===========================================================================
1099
1371
  # random variables
1100
1372
  # ===========================================================================
1101
1373
 
1102
- # distributions with complete reparameterization support:
1103
- # KronDelta: complete
1104
- # DiracDelta: complete
1105
- # Uniform: complete
1106
- # Bernoulli: complete (subclass uses Gumbel-softmax)
1107
- # Normal: complete
1108
- # Exponential: complete
1109
- # Geometric: complete
1110
- # Weibull: complete
1111
- # Pareto: complete
1112
- # Gumbel: complete
1113
- # Laplace: complete
1114
- # Cauchy: complete
1115
- # Gompertz: complete
1116
- # Kumaraswamy: complete
1117
- # Discrete: complete (subclass uses Gumbel-softmax)
1118
- # UnnormDiscrete: complete (subclass uses Gumbel-softmax)
1119
- # Discrete(p): complete (subclass uses Gumbel-softmax)
1120
- # UnnormDiscrete(p): complete (subclass uses Gumbel-softmax)
1121
- # Poisson (subclass uses Gumbel-softmax or Poisson process trick)
1122
- # Binomial (subclass uses Gumbel-softmax or Normal approximation)
1123
- # NegativeBinomial (subclass uses Poisson-Gamma mixture)
1124
-
1125
1374
  # distributions which seem to support backpropagation (need more testing):
1126
1375
  # Beta
1127
1376
  # Student
@@ -1132,656 +1381,587 @@ class JaxRDDLCompiler:
1132
1381
  # distributions with incomplete reparameterization support (TODO):
1133
1382
  # Multinomial
1134
1383
 
1135
def _jax_random(self, expr, aux):
    """Dispatch compilation of a random-sampling expression by distribution name.

    The Discrete/UnnormDiscrete variants share one compiler, distinguished by
    the ``unnorm`` flag (whether the probabilities need normalization).
    Raises RDDLNotImplementedError for unsupported distributions.
    """
    _, name = expr.etype
    if name == 'KronDelta':
        return self._jax_kron(expr, aux)
    elif name == 'DiracDelta':
        return self._jax_dirac(expr, aux)
    elif name == 'Uniform':
        return self._jax_uniform(expr, aux)
    elif name == 'Bernoulli':
        return self._jax_bernoulli(expr, aux)
    elif name == 'Normal':
        return self._jax_normal(expr, aux)
    elif name == 'Poisson':
        return self._jax_poisson(expr, aux)
    elif name == 'Exponential':
        return self._jax_exponential(expr, aux)
    elif name == 'Weibull':
        return self._jax_weibull(expr, aux)
    elif name == 'Gamma':
        return self._jax_gamma(expr, aux)
    elif name == 'Binomial':
        return self._jax_binomial(expr, aux)
    elif name == 'NegativeBinomial':
        return self._jax_negative_binomial(expr, aux)
    elif name == 'Beta':
        return self._jax_beta(expr, aux)
    elif name == 'Geometric':
        return self._jax_geometric(expr, aux)
    elif name == 'Pareto':
        return self._jax_pareto(expr, aux)
    elif name == 'Student':
        return self._jax_student(expr, aux)
    elif name == 'Gumbel':
        return self._jax_gumbel(expr, aux)
    elif name == 'Laplace':
        return self._jax_laplace(expr, aux)
    elif name == 'Cauchy':
        return self._jax_cauchy(expr, aux)
    elif name == 'Gompertz':
        return self._jax_gompertz(expr, aux)
    elif name == 'ChiSquare':
        return self._jax_chisquare(expr, aux)
    elif name == 'Kumaraswamy':
        return self._jax_kumaraswamy(expr, aux)
    elif name == 'Discrete':
        return self._jax_discrete(expr, aux, unnorm=False)
    elif name == 'UnnormDiscrete':
        return self._jax_discrete(expr, aux, unnorm=True)
    elif name == 'Discrete(p)':
        return self._jax_discrete_pvar(expr, aux, unnorm=False)
    elif name == 'UnnormDiscrete(p)':
        return self._jax_discrete_pvar(expr, aux, unnorm=True)
    else:
        raise RDDLNotImplementedError(
            f'Distribution {name} is not supported.\n' + print_stack_trace(expr))
 
1191
def _jax_kron(self, expr, aux):
    """Compile KronDelta(x): deterministic, passes the sample through unchanged.

    Raises the INVALID_PARAM_KRON_DELTA error flag when the argument sample
    cannot be cast to the compiler's integer dtype.
    """
    ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_KRON_DELTA']
    JaxRDDLCompiler._check_num_args(expr, 1)
    inner_expr, = expr.args
    jax_inner = self._jax(inner_expr, aux)

    # just check that the sample can be cast to int
    def _jax_wrapped_distribution_kron(fls, nfls, params, key):
        sample, key, err, params = jax_inner(fls, nfls, params, key)
        not_int_castable = jnp.logical_not(jnp.can_cast(sample, self.INT))
        return sample, key, err | (not_int_castable * ERR), params

    return _jax_wrapped_distribution_kron
 
1206
def _jax_dirac(self, expr, aux):
    """Compile DiracDelta(x): deterministic, simply the argument cast to real."""
    JaxRDDLCompiler._check_num_args(expr, 1)
    inner_expr, = expr.args
    return self._jax(inner_expr, aux, dtype=self.REAL)
 
1212
def _jax_uniform(self, expr, aux):
    """Compile Uniform(lb, ub) with the reparameterization lb + (ub - lb) * U(0, 1).

    Raises INVALID_PARAM_UNIFORM when lb > ub anywhere.
    """
    ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_UNIFORM']
    JaxRDDLCompiler._check_num_args(expr, 2)
    lb_expr, ub_expr = expr.args
    jax_lb = self._jax(lb_expr, aux)
    jax_ub = self._jax(ub_expr, aux)

    def _jax_wrapped_distribution_uniform(fls, nfls, params, key):
        lb, key, err1, params = jax_lb(fls, nfls, params, key)
        ub, key, err2, params = jax_ub(fls, nfls, params, key)
        key, subkey = random.split(key)
        noise = random.uniform(key=subkey, shape=jnp.shape(lb), dtype=self.REAL)
        sample = lb + (ub - lb) * noise
        invalid = jnp.logical_not(jnp.all(lb <= ub))
        return sample, key, err1 | err2 | (invalid * ERR), params

    return _jax_wrapped_distribution_uniform
 
1233
def _jax_normal(self, expr, aux):
    """Compile Normal(mean, var) with the reparameterization mean + sqrt(var) * N(0, 1).

    Raises INVALID_PARAM_NORMAL when var < 0 anywhere.
    """
    ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_NORMAL']
    JaxRDDLCompiler._check_num_args(expr, 2)
    mean_expr, var_expr = expr.args
    jax_mean = self._jax(mean_expr, aux)
    jax_var = self._jax(var_expr, aux)

    def _jax_wrapped_distribution_normal(fls, nfls, params, key):
        mean, key, err1, params = jax_mean(fls, nfls, params, key)
        var, key, err2, params = jax_var(fls, nfls, params, key)
        key, subkey = random.split(key)
        noise = random.normal(key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
        sample = mean + jnp.sqrt(var) * noise
        invalid = jnp.logical_not(jnp.all(var >= 0))
        return sample, key, err1 | err2 | (invalid * ERR), params

    return _jax_wrapped_distribution_normal
 
1255
def _jax_exponential(self, expr, aux):
    """Compile Exponential(scale) with the reparameterization scale * Exp(1).

    Raises INVALID_PARAM_EXPONENTIAL when scale <= 0 anywhere.
    """
    ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_EXPONENTIAL']
    JaxRDDLCompiler._check_num_args(expr, 1)
    scale_expr, = expr.args
    jax_scale = self._jax(scale_expr, aux)

    def _jax_wrapped_distribution_exp(fls, nfls, params, key):
        scale, key, err, params = jax_scale(fls, nfls, params, key)
        key, subkey = random.split(key)
        unit_sample = random.exponential(
            key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
        invalid = jnp.logical_not(jnp.all(scale > 0))
        return scale * unit_sample, key, err | (invalid * ERR), params

    return _jax_wrapped_distribution_exp
 
1274
def _jax_weibull(self, expr, aux):
    """Compile Weibull(shape, scale) sampling via jax.random.weibull_min.

    Raises INVALID_PARAM_WEIBULL unless shape > 0 and scale > 0 everywhere.
    """
    ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_WEIBULL']
    JaxRDDLCompiler._check_num_args(expr, 2)

    arg_shape, arg_scale = expr.args
    jax_shape = self._jax(arg_shape, aux)
    jax_scale = self._jax(arg_scale, aux)

    # reparameterization trick W(s, r) = r * (-ln(1 - U(0, 1))) ** (1 / s)
    def _jax_wrapped_distribution_weibull(fls, nfls, params, key):
        shape, key, err1, params = jax_shape(fls, nfls, params, key)
        scale, key, err2, params = jax_scale(fls, nfls, params, key)
        key, subkey = random.split(key)
        sample = random.weibull_min(
            key=subkey, scale=scale, concentration=shape, dtype=self.REAL)
        out_of_bounds = jnp.logical_not(jnp.all(jnp.logical_and(shape > 0, scale > 0)))
        err = err1 | err2 | (out_of_bounds * ERR)
        return sample, key, err, params
    return _jax_wrapped_distribution_weibull
 
1295
def _jax_bernoulli(self, expr, aux):
    """Compile Bernoulli(p) sampling via jax.random.bernoulli.

    Raises INVALID_PARAM_BERNOULLI unless 0 <= p <= 1 everywhere.
    """
    ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_BERNOULLI']
    JaxRDDLCompiler._check_num_args(expr, 1)
    prob_expr, = expr.args

    # recursively compile arguments
    jax_prob = self._jax(prob_expr, aux)

    def _jax_wrapped_distribution_bernoulli(fls, nfls, params, key):
        prob, key, err, params = jax_prob(fls, nfls, params, key)
        key, subkey = random.split(key)
        sample = random.bernoulli(subkey, prob)
        in_range = jnp.logical_and(prob >= 0, prob <= 1)
        invalid = jnp.logical_not(jnp.all(in_range))
        return sample, key, err | (invalid * ERR), params

    return _jax_wrapped_distribution_bernoulli
 
1320
def _jax_poisson(self, expr, aux):
    """Compile Poisson(rate) sampling via jax.random.poisson.

    Raises INVALID_PARAM_POISSON when rate < 0 anywhere.
    """
    ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_POISSON']
    JaxRDDLCompiler._check_num_args(expr, 1)
    rate_expr, = expr.args

    # recursively compile arguments
    jax_rate = self._jax(rate_expr, aux)

    # uses the implicit JAX subroutine
    def _jax_wrapped_distribution_poisson(fls, nfls, params, key):
        rate, key, err, params = jax_rate(fls, nfls, params, key)
        key, subkey = random.split(key)
        sample = random.poisson(key=subkey, lam=rate, dtype=self.INT)
        invalid = jnp.logical_not(jnp.all(rate >= 0))
        return sample, key, err | (invalid * ERR), params

    return _jax_wrapped_distribution_poisson
 
1346
def _jax_gamma(self, expr, aux):
    """Compile Gamma(shape, scale) sampling.

    Uses the partial reparameterization Gamma(s, r) = r * Gamma(s, 1), with
    the implicit JAX subroutine for Gamma(s, 1).
    Raises INVALID_PARAM_GAMMA unless shape > 0 and scale > 0 everywhere.

    NOTE(review): a unary math-function compiler with the same name
    `_jax_gamma` appears earlier in this class; this later definition
    overrides it, so gamma-the-function would dispatch here and fail the
    two-argument check. One of the two should be renamed.
    """
    ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_GAMMA']
    JaxRDDLCompiler._check_num_args(expr, 2)

    arg_shape, arg_scale = expr.args
    jax_shape = self._jax(arg_shape, aux)
    jax_scale = self._jax(arg_scale, aux)

    # partial reparameterization trick Gamma(s, r) = r * Gamma(s, 1)
    # uses the implicit JAX subroutine for Gamma(s, 1)
    def _jax_wrapped_distribution_gamma(fls, nfls, params, key):
        shape, key, err1, params = jax_shape(fls, nfls, params, key)
        scale, key, err2, params = jax_scale(fls, nfls, params, key)
        key, subkey = random.split(key)
        Gamma = random.gamma(key=subkey, a=shape, dtype=self.REAL)
        sample = scale * Gamma
        out_of_bounds = jnp.logical_not(jnp.all(jnp.logical_and(shape > 0, scale > 0)))
        err = err1 | err2 | (out_of_bounds * ERR)
        return sample, key, err, params
    return _jax_wrapped_distribution_gamma
 
1368
def _jax_binomial(self, expr, aux):
    """Compile Binomial(n, p) sampling via jax.random.binomial.

    Trials and prob are cast to real for the sampler (which returns real),
    and the sample is cast back to int.
    Raises INVALID_PARAM_BINOMIAL unless 0 <= p <= 1 and n >= 0 everywhere.
    """
    ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_BINOMIAL']
    JaxRDDLCompiler._check_num_args(expr, 2)
    arg_trials, arg_prob = expr.args

    jax_trials = self._jax(arg_trials, aux)
    jax_prob = self._jax(arg_prob, aux)

    # uses reduction for constant trials
    def _jax_wrapped_distribution_binomial(fls, nfls, params, key):
        trials, key, err2, params = jax_trials(fls, nfls, params, key)
        prob, key, err1, params = jax_prob(fls, nfls, params, key)
        key, subkey = random.split(key)
        trials = jnp.asarray(trials, dtype=self.REAL)
        prob = jnp.asarray(prob, dtype=self.REAL)
        sample = random.binomial(key=subkey, n=trials, p=prob, dtype=self.REAL)
        sample = jnp.asarray(sample, dtype=self.INT)
        out_of_bounds = jnp.logical_not(jnp.all(
            jnp.logical_and(jnp.logical_and(prob >= 0, prob <= 1), trials >= 0)))
        err = err1 | err2 | (out_of_bounds * ERR)
        return sample, key, err, params
    return _jax_wrapped_distribution_binomial
1617
 
1398
- def _jax_negative_binomial(self, expr, init_params):
1618
+ def _jax_negative_binomial(self, expr, aux):
1399
1619
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_NEGATIVE_BINOMIAL']
1400
1620
  JaxRDDLCompiler._check_num_args(expr, 2)
1401
1621
  arg_trials, arg_prob = expr.args
1402
1622
 
1403
- # if prob is non-fluent, always use the exact operation
1404
- if self.compile_non_fluent_exact \
1405
- and not self.traced.cached_is_fluent(arg_trials) \
1406
- and not self.traced.cached_is_fluent(arg_prob):
1407
- negbin_op = self.EXACT_OPS['sampling']['NegativeBinomial']
1408
- else:
1409
- negbin_op = self.OPS['sampling']['NegativeBinomial']
1410
- jax_op = negbin_op(expr.id, init_params)
1411
-
1412
- jax_trials = self._jax(arg_trials, init_params)
1413
- jax_prob = self._jax(arg_prob, init_params)
1623
+ jax_trials = self._jax(arg_trials, aux)
1624
+ jax_prob = self._jax(arg_prob, aux)
1414
1625
 
1415
- # uses the JAX substrate of tensorflow-probability
1416
- def _jax_wrapped_distribution_negative_binomial(x, params, key):
1417
- trials, key, err2, params = jax_trials(x, params, key)
1418
- prob, key, err1, params = jax_prob(x, params, key)
1626
+ # uses tensorflow-probability
1627
+ def _jax_wrapped_distribution_negative_binomial(fls, nfls, params, key):
1628
+ trials, key, err2, params = jax_trials(fls, nfls, params, key)
1629
+ prob, key, err1, params = jax_prob(fls, nfls, params, key)
1419
1630
  key, subkey = random.split(key)
1420
- sample, params = jax_op(subkey, trials, prob, params)
1631
+ trials = jnp.asarray(trials, dtype=self.REAL)
1632
+ prob = jnp.asarray(prob, dtype=self.REAL)
1633
+ dist = tfp.distributions.NegativeBinomial(total_count=trials, probs=1. - prob)
1634
+ sample = jnp.asarray(dist.sample(seed=subkey), dtype=self.INT)
1421
1635
  out_of_bounds = jnp.logical_not(jnp.all(
1422
- (prob >= 0) & (prob <= 1) & (trials > 0)))
1636
+ jnp.logical_and(jnp.logical_and(prob >= 0, prob <= 1), trials > 0)))
1423
1637
  err = err1 | err2 | (out_of_bounds * ERR)
1424
- return sample, key, err, params
1425
-
1638
+ return sample, key, err, params
1426
1639
  return _jax_wrapped_distribution_negative_binomial
1427
1640
 
1428
def _jax_beta(self, expr, aux):
    """Compile Beta(a, b) sampling via jax.random.beta.

    Raises INVALID_PARAM_BETA unless both parameters are > 0 everywhere.
    """
    ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_BETA']
    JaxRDDLCompiler._check_num_args(expr, 2)
    shape_expr, rate_expr = expr.args
    jax_shape = self._jax(shape_expr, aux)
    jax_rate = self._jax(rate_expr, aux)

    # uses the implicit JAX subroutine
    def _jax_wrapped_distribution_beta(fls, nfls, params, key):
        shape, key, err1, params = jax_shape(fls, nfls, params, key)
        rate, key, err2, params = jax_rate(fls, nfls, params, key)
        key, subkey = random.split(key)
        sample = random.beta(key=subkey, a=shape, b=rate, dtype=self.REAL)
        invalid = jnp.logical_not(jnp.all(jnp.logical_and(shape > 0, rate > 0)))
        return sample, key, err1 | err2 | (invalid * ERR), params

    return _jax_wrapped_distribution_beta
 
1448
def _jax_geometric(self, expr, aux):
    """Compile Geometric(p) sampling via jax.random.geometric.

    Raises INVALID_PARAM_GEOMETRIC unless 0 <= p <= 1 everywhere.
    """
    ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_GEOMETRIC']
    JaxRDDLCompiler._check_num_args(expr, 1)
    arg_prob, = expr.args

    # recursively compile arguments
    jax_prob = self._jax(arg_prob, aux)

    def _jax_wrapped_distribution_geometric(fls, nfls, params, key):
        prob, key, err, params = jax_prob(fls, nfls, params, key)
        key, subkey = random.split(key)
        sample = random.geometric(key=subkey, p=prob, dtype=self.INT)
        out_of_bounds = jnp.logical_not(jnp.all(jnp.logical_and(prob >= 0, prob <= 1)))
        err = err | (out_of_bounds * ERR)
        return sample, key, err, params
    return _jax_wrapped_distribution_geometric
 
1473
def _jax_pareto(self, expr, aux):
    """Compile Pareto(shape, scale) sampling.

    Uses the partial reparameterization Pareto(s, r) = r * Pareto(s, 1),
    with the implicit JAX subroutine for Pareto(s, 1).
    Raises INVALID_PARAM_PARETO unless shape > 0 and scale > 0 everywhere.
    """
    ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_PARETO']
    JaxRDDLCompiler._check_num_args(expr, 2)
    shape_expr, scale_expr = expr.args
    jax_shape = self._jax(shape_expr, aux)
    jax_scale = self._jax(scale_expr, aux)

    def _jax_wrapped_distribution_pareto(fls, nfls, params, key):
        shape, key, err1, params = jax_shape(fls, nfls, params, key)
        scale, key, err2, params = jax_scale(fls, nfls, params, key)
        key, subkey = random.split(key)
        sample = scale * random.pareto(key=subkey, b=shape, dtype=self.REAL)
        invalid = jnp.logical_not(jnp.all(jnp.logical_and(shape > 0, scale > 0)))
        return sample, key, err1 | err2 | (invalid * ERR), params

    return _jax_wrapped_distribution_pareto
 
1494
- def _jax_student(self, expr, init_params):
1697
+ def _jax_student(self, expr, aux):
1495
1698
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_STUDENT']
1496
1699
  JaxRDDLCompiler._check_num_args(expr, 1)
1497
1700
 
1498
1701
  arg_df, = expr.args
1499
- jax_df = self._jax(arg_df, init_params)
1702
+ jax_df = self._jax(arg_df, aux)
1500
1703
 
1501
1704
  # uses the implicit JAX subroutine for student(df)
1502
- def _jax_wrapped_distribution_t(x, params, key):
1503
- df, key, err, params = jax_df(x, params, key)
1705
+ def _jax_wrapped_distribution_t(fls, nfls, params, key):
1706
+ df, key, err, params = jax_df(fls, nfls, params, key)
1504
1707
  key, subkey = random.split(key)
1505
1708
  sample = random.t(key=subkey, df=df, shape=jnp.shape(df), dtype=self.REAL)
1506
1709
  out_of_bounds = jnp.logical_not(jnp.all(df > 0))
1507
- err |= (out_of_bounds * ERR)
1508
- return sample, key, err, params
1509
-
1710
+ err = err | (out_of_bounds * ERR)
1711
+ return sample, key, err, params
1510
1712
  return _jax_wrapped_distribution_t
1511
1713
 
1512
- def _jax_gumbel(self, expr, init_params):
1714
+ def _jax_gumbel(self, expr, aux):
1513
1715
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_GUMBEL']
1514
1716
  JaxRDDLCompiler._check_num_args(expr, 2)
1515
1717
 
1516
1718
  arg_mean, arg_scale = expr.args
1517
- jax_mean = self._jax(arg_mean, init_params)
1518
- jax_scale = self._jax(arg_scale, init_params)
1719
+ jax_mean = self._jax(arg_mean, aux)
1720
+ jax_scale = self._jax(arg_scale, aux)
1519
1721
 
1520
1722
  # reparameterization trick Gumbel(m, s) = m + s * Gumbel(0, 1)
1521
- def _jax_wrapped_distribution_gumbel(x, params, key):
1522
- mean, key, err1, params = jax_mean(x, params, key)
1523
- scale, key, err2, params = jax_scale(x, params, key)
1723
+ def _jax_wrapped_distribution_gumbel(fls, nfls, params, key):
1724
+ mean, key, err1, params = jax_mean(fls, nfls, params, key)
1725
+ scale, key, err2, params = jax_scale(fls, nfls, params, key)
1524
1726
  key, subkey = random.split(key)
1525
- Gumbel01 = random.gumbel(key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
1526
- sample = mean + scale * Gumbel01
1727
+ g = random.gumbel(key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
1728
+ sample = mean + scale * g
1527
1729
  out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
1528
1730
  err = err1 | err2 | (out_of_bounds * ERR)
1529
- return sample, key, err, params
1530
-
1731
+ return sample, key, err, params
1531
1732
  return _jax_wrapped_distribution_gumbel
1532
1733
 
1533
- def _jax_laplace(self, expr, init_params):
1734
+ def _jax_laplace(self, expr, aux):
1534
1735
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_LAPLACE']
1535
1736
  JaxRDDLCompiler._check_num_args(expr, 2)
1536
1737
 
1537
1738
  arg_mean, arg_scale = expr.args
1538
- jax_mean = self._jax(arg_mean, init_params)
1539
- jax_scale = self._jax(arg_scale, init_params)
1739
+ jax_mean = self._jax(arg_mean, aux)
1740
+ jax_scale = self._jax(arg_scale, aux)
1540
1741
 
1541
1742
  # reparameterization trick Laplace(m, s) = m + s * Laplace(0, 1)
1542
- def _jax_wrapped_distribution_laplace(x, params, key):
1543
- mean, key, err1, params = jax_mean(x, params, key)
1544
- scale, key, err2, params = jax_scale(x, params, key)
1743
+ def _jax_wrapped_distribution_laplace(fls, nfls, params, key):
1744
+ mean, key, err1, params = jax_mean(fls, nfls, params, key)
1745
+ scale, key, err2, params = jax_scale(fls, nfls, params, key)
1545
1746
  key, subkey = random.split(key)
1546
- Laplace01 = random.laplace(key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
1547
- sample = mean + scale * Laplace01
1747
+ lp = random.laplace(key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
1748
+ sample = mean + scale * lp
1548
1749
  out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
1549
1750
  err = err1 | err2 | (out_of_bounds * ERR)
1550
- return sample, key, err, params
1551
-
1751
+ return sample, key, err, params
1552
1752
  return _jax_wrapped_distribution_laplace
1553
1753
 
1554
- def _jax_cauchy(self, expr, init_params):
1754
+ def _jax_cauchy(self, expr, aux):
1555
1755
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_CAUCHY']
1556
1756
  JaxRDDLCompiler._check_num_args(expr, 2)
1557
1757
 
1558
1758
  arg_mean, arg_scale = expr.args
1559
- jax_mean = self._jax(arg_mean, init_params)
1560
- jax_scale = self._jax(arg_scale, init_params)
1759
+ jax_mean = self._jax(arg_mean, aux)
1760
+ jax_scale = self._jax(arg_scale, aux)
1561
1761
 
1562
1762
  # reparameterization trick Cauchy(m, s) = m + s * Cauchy(0, 1)
1563
- def _jax_wrapped_distribution_cauchy(x, params, key):
1564
- mean, key, err1, params = jax_mean(x, params, key)
1565
- scale, key, err2, params = jax_scale(x, params, key)
1763
+ def _jax_wrapped_distribution_cauchy(fls, nfls, params, key):
1764
+ mean, key, err1, params = jax_mean(fls, nfls, params, key)
1765
+ scale, key, err2, params = jax_scale(fls, nfls, params, key)
1566
1766
  key, subkey = random.split(key)
1567
- Cauchy01 = random.cauchy(key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
1568
- sample = mean + scale * Cauchy01
1767
+ cauchy = random.cauchy(key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
1768
+ sample = mean + scale * cauchy
1569
1769
  out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
1570
1770
  err = err1 | err2 | (out_of_bounds * ERR)
1571
- return sample, key, err, params
1572
-
1771
+ return sample, key, err, params
1573
1772
  return _jax_wrapped_distribution_cauchy
1574
1773
 
1575
- def _jax_gompertz(self, expr, init_params):
1774
+ def _jax_gompertz(self, expr, aux):
1576
1775
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_GOMPERTZ']
1577
1776
  JaxRDDLCompiler._check_num_args(expr, 2)
1578
1777
 
1579
1778
  arg_shape, arg_scale = expr.args
1580
- jax_shape = self._jax(arg_shape, init_params)
1581
- jax_scale = self._jax(arg_scale, init_params)
1779
+ jax_shape = self._jax(arg_shape, aux)
1780
+ jax_scale = self._jax(arg_scale, aux)
1582
1781
 
1583
1782
  # reparameterization trick Gompertz(s, r) = ln(1 - log(U(0, 1)) / s) / r
1584
- def _jax_wrapped_distribution_gompertz(x, params, key):
1585
- shape, key, err1, params = jax_shape(x, params, key)
1586
- scale, key, err2, params = jax_scale(x, params, key)
1783
+ def _jax_wrapped_distribution_gompertz(fls, nfls, params, key):
1784
+ shape, key, err1, params = jax_shape(fls, nfls, params, key)
1785
+ scale, key, err2, params = jax_scale(fls, nfls, params, key)
1587
1786
  key, subkey = random.split(key)
1588
1787
  U = random.uniform(key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
1589
1788
  sample = jnp.log(1.0 - jnp.log1p(-U) / shape) / scale
1590
- out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
1789
+ out_of_bounds = jnp.logical_not(jnp.all(jnp.logical_and(shape > 0, scale > 0)))
1591
1790
  err = err1 | err2 | (out_of_bounds * ERR)
1592
- return sample, key, err, params
1593
-
1791
+ return sample, key, err, params
1594
1792
  return _jax_wrapped_distribution_gompertz
1595
1793
 
1596
- def _jax_chisquare(self, expr, init_params):
1794
+ def _jax_chisquare(self, expr, aux):
1597
1795
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_CHISQUARE']
1598
1796
  JaxRDDLCompiler._check_num_args(expr, 1)
1599
1797
 
1600
1798
  arg_df, = expr.args
1601
- jax_df = self._jax(arg_df, init_params)
1799
+ jax_df = self._jax(arg_df, aux)
1602
1800
 
1603
1801
  # use the fact that ChiSquare(df) = Gamma(df/2, 2)
1604
- def _jax_wrapped_distribution_chisquare(x, params, key):
1605
- df, key, err1, params = jax_df(x, params, key)
1802
+ def _jax_wrapped_distribution_chisquare(fls, nfls, params, key):
1803
+ df, key, err1, params = jax_df(fls, nfls, params, key)
1606
1804
  key, subkey = random.split(key)
1607
- shape = df / 2.0
1608
- Gamma = random.gamma(key=subkey, a=shape, dtype=self.REAL)
1609
- sample = 2.0 * Gamma
1805
+ shape = 0.5 * df
1806
+ sample = 2.0 * random.gamma(key=subkey, a=shape, dtype=self.REAL)
1610
1807
  out_of_bounds = jnp.logical_not(jnp.all(df > 0))
1611
1808
  err = err1 | (out_of_bounds * ERR)
1612
- return sample, key, err, params
1613
-
1809
+ return sample, key, err, params
1614
1810
  return _jax_wrapped_distribution_chisquare
1615
1811
 
1616
- def _jax_kumaraswamy(self, expr, init_params):
1812
+ def _jax_kumaraswamy(self, expr, aux):
1617
1813
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_KUMARASWAMY']
1618
1814
  JaxRDDLCompiler._check_num_args(expr, 2)
1619
1815
 
1620
1816
  arg_a, arg_b = expr.args
1621
- jax_a = self._jax(arg_a, init_params)
1622
- jax_b = self._jax(arg_b, init_params)
1817
+ jax_a = self._jax(arg_a, aux)
1818
+ jax_b = self._jax(arg_b, aux)
1623
1819
 
1624
1820
  # uses the reparameterization K(a, b) = (1 - (1 - U(0, 1))^{1/b})^{1/a}
1625
- def _jax_wrapped_distribution_kumaraswamy(x, params, key):
1626
- a, key, err1, params = jax_a(x, params, key)
1627
- b, key, err2, params = jax_b(x, params, key)
1821
+ def _jax_wrapped_distribution_kumaraswamy(fls, nfls, params, key):
1822
+ a, key, err1, params = jax_a(fls, nfls, params, key)
1823
+ b, key, err2, params = jax_b(fls, nfls, params, key)
1628
1824
  key, subkey = random.split(key)
1629
1825
  U = random.uniform(key=subkey, shape=jnp.shape(a), dtype=self.REAL)
1630
1826
  sample = jnp.power(1.0 - jnp.power(U, 1.0 / b), 1.0 / a)
1631
- out_of_bounds = jnp.logical_not(jnp.all((a > 0) & (b > 0)))
1827
+ out_of_bounds = jnp.logical_not(jnp.all(jnp.logical_and(a > 0, b > 0)))
1632
1828
  err = err1 | err2 | (out_of_bounds * ERR)
1633
- return sample, key, err, params
1634
-
1829
+ return sample, key, err, params
1635
1830
  return _jax_wrapped_distribution_kumaraswamy
1636
1831
 
1637
1832
  # ===========================================================================
1638
1833
  # random variables with enum support
1639
1834
  # ===========================================================================
1640
1835
 
1641
- def _jax_discrete(self, expr, init_params, unnorm):
1642
- NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
1836
+ @staticmethod
1837
+ def _jax_update_discrete_oob_error(err, prob):
1643
1838
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_DISCRETE']
1644
- ordered_args = self.traced.cached_sim_info(expr)
1645
-
1646
- # if all probabilities are non-fluent, then always sample exact
1647
- has_fluent_arg = any(self.traced.cached_is_fluent(arg)
1648
- for arg in ordered_args)
1649
- if self.compile_non_fluent_exact and not has_fluent_arg:
1650
- discrete_op = self.EXACT_OPS['sampling']['Discrete']
1651
- else:
1652
- discrete_op = self.OPS['sampling']['Discrete']
1653
- jax_op = discrete_op(expr.id, init_params)
1654
-
1655
- # compile probability expressions
1656
- jax_probs = [self._jax(arg, init_params) for arg in ordered_args]
1657
-
1658
- def _jax_wrapped_distribution_discrete(x, params, key):
1659
-
1660
- # sample case probabilities and normalize as needed
1661
- error = NORMAL
1662
- prob = [None] * len(jax_probs)
1663
- for (i, jax_prob) in enumerate(jax_probs):
1664
- prob[i], key, error_pdf, params = jax_prob(x, params, key)
1665
- error |= error_pdf
1839
+ out_of_bounds = jnp.logical_not(jnp.logical_and(
1840
+ jnp.all(prob >= 0),
1841
+ jnp.allclose(jnp.sum(prob, axis=-1), 1.0)
1842
+ ))
1843
+ error = err | (out_of_bounds * ERR)
1844
+ return error
1845
+
1846
+ def _jax_discrete_prob(self, jax_probs, unnormalized):
1847
+ def _jax_wrapped_calc_discrete_prob(fls, nfls, params, key):
1848
+
1849
+ # calculate probability expressions
1850
+ error = JaxRDDLCompiler.ERROR_CODES['NORMAL']
1851
+ prob = []
1852
+ for jax_prob in jax_probs:
1853
+ sample, key, error_pdf, params = jax_prob(fls, nfls, params, key)
1854
+ prob.append(sample)
1855
+ error = error | error_pdf
1666
1856
  prob = jnp.stack(prob, axis=-1)
1667
- if unnorm:
1857
+
1858
+ # normalize them if required
1859
+ if unnormalized:
1668
1860
  normalizer = jnp.sum(prob, axis=-1, keepdims=True)
1669
1861
  prob = prob / normalizer
1670
-
1671
- # dispatch to sampling subroutine
1672
- key, subkey = random.split(key)
1673
- sample, params = jax_op(subkey, prob, params)
1674
- out_of_bounds = jnp.logical_not(jnp.logical_and(
1675
- jnp.all(prob >= 0),
1676
- jnp.allclose(jnp.sum(prob, axis=-1), 1.0)
1677
- ))
1678
- error |= (out_of_bounds * ERR)
1679
- return sample, key, error, params
1862
+ return prob, key, error, params
1863
+ return _jax_wrapped_calc_discrete_prob
1864
+
1865
+ def _jax_discrete(self, expr, aux, unnorm):
1866
+ ordered_args = self.traced.cached_sim_info(expr)
1867
+ jax_probs = [self._jax(arg, aux) for arg in ordered_args]
1868
+ prob_fn = self._jax_discrete_prob(jax_probs, unnorm)
1680
1869
 
1870
+ def _jax_wrapped_distribution_discrete(fls, nfls, params, key):
1871
+ prob, key, error, params = prob_fn(fls, nfls, params, key)
1872
+ key, subkey = random.split(key)
1873
+ sample = random.categorical(key=subkey, logits=jnp.log(prob), axis=-1)
1874
+ error = JaxRDDLCompiler._jax_update_discrete_oob_error(error, prob)
1875
+ return sample, key, error, params
1681
1876
  return _jax_wrapped_distribution_discrete
1682
-
1683
- def _jax_discrete_pvar(self, expr, init_params, unnorm):
1684
- ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_DISCRETE']
1877
+
1878
+ @staticmethod
1879
+ def _jax_discrete_pvar_prob(jax_probs, unnormalized):
1880
+ def _jax_wrapped_calc_discrete_prob(fls, nfls, params, key):
1881
+ prob, key, error, params = jax_probs(fls, nfls, params, key)
1882
+ if unnormalized:
1883
+ normalizer = jnp.sum(prob, axis=-1, keepdims=True)
1884
+ prob = prob / normalizer
1885
+ return prob, key, error, params
1886
+ return _jax_wrapped_calc_discrete_prob
1887
+
1888
+ def _jax_discrete_pvar(self, expr, aux, unnorm):
1685
1889
  JaxRDDLCompiler._check_num_args(expr, 2)
1686
1890
  _, args = expr.args
1687
1891
  arg, = args
1688
-
1689
- # if probabilities are non-fluent, then always sample exact
1690
- if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg):
1691
- discrete_op = self.EXACT_OPS['sampling']['Discrete']
1692
- else:
1693
- discrete_op = self.OPS['sampling']['Discrete']
1694
- jax_op = discrete_op(expr.id, init_params)
1695
-
1696
- # compile probability function
1697
- jax_probs = self._jax(arg, init_params)
1892
+ jax_probs = self._jax(arg, aux)
1893
+ prob_fn = self._jax_discrete_pvar_prob(jax_probs, unnorm)
1698
1894
 
1699
- def _jax_wrapped_distribution_discrete_pvar(x, params, key):
1700
-
1701
- # sample probabilities
1702
- prob, key, error, params = jax_probs(x, params, key)
1703
- if unnorm:
1704
- normalizer = jnp.sum(prob, axis=-1, keepdims=True)
1705
- prob = prob / normalizer
1706
-
1707
- # dispatch to sampling subroutine
1895
+ def _jax_wrapped_distribution_discrete_pvar(fls, nfls, params, key):
1896
+ prob, key, error, params = prob_fn(fls, nfls, params, key)
1708
1897
  key, subkey = random.split(key)
1709
- sample, params = jax_op(subkey, prob, params)
1710
- out_of_bounds = jnp.logical_not(jnp.logical_and(
1711
- jnp.all(prob >= 0),
1712
- jnp.allclose(jnp.sum(prob, axis=-1), 1.0)
1713
- ))
1714
- error |= (out_of_bounds * ERR)
1715
- return sample, key, error, params
1716
-
1898
+ sample = random.categorical(key=subkey, logits=jnp.log(prob), axis=-1)
1899
+ error = JaxRDDLCompiler._jax_update_discrete_oob_error(error, prob)
1900
+ return sample, key, error, params
1717
1901
  return _jax_wrapped_distribution_discrete_pvar
1718
1902
 
1719
1903
  # ===========================================================================
1720
1904
  # random vectors
1721
1905
  # ===========================================================================
1722
1906
 
1723
- def _jax_random_vector(self, expr, init_params):
1907
+ def _jax_random_vector(self, expr, aux):
1724
1908
  _, name = expr.etype
1725
1909
  if name == 'MultivariateNormal':
1726
- return self._jax_multivariate_normal(expr, init_params)
1910
+ return self._jax_multivariate_normal(expr, aux)
1727
1911
  elif name == 'MultivariateStudent':
1728
- return self._jax_multivariate_student(expr, init_params)
1912
+ return self._jax_multivariate_student(expr, aux)
1729
1913
  elif name == 'Dirichlet':
1730
- return self._jax_dirichlet(expr, init_params)
1914
+ return self._jax_dirichlet(expr, aux)
1731
1915
  elif name == 'Multinomial':
1732
- return self._jax_multinomial(expr, init_params)
1916
+ return self._jax_multinomial(expr, aux)
1733
1917
  else:
1734
1918
  raise RDDLNotImplementedError(
1735
1919
  f'Distribution {name} is not supported.\n' + print_stack_trace(expr))
1736
1920
 
1737
- def _jax_multivariate_normal(self, expr, init_params):
1921
+ def _jax_multivariate_normal(self, expr, aux):
1738
1922
  _, args = expr.args
1739
1923
  mean, cov = args
1740
- jax_mean = self._jax(mean, init_params)
1741
- jax_cov = self._jax(cov, init_params)
1924
+ jax_mean = self._jax(mean, aux)
1925
+ jax_cov = self._jax(cov, aux)
1742
1926
  index, = self.traced.cached_sim_info(expr)
1743
1927
 
1744
1928
  # reparameterization trick MN(m, LL') = LZ + m, where Z ~ Normal(0, 1)
1745
- def _jax_wrapped_distribution_multivariate_normal(x, params, key):
1929
+ def _jax_wrapped_distribution_multivariate_normal(fls, nfls, params, key):
1746
1930
 
1747
1931
  # sample the mean and covariance
1748
- sample_mean, key, err1, params = jax_mean(x, params, key)
1749
- sample_cov, key, err2, params = jax_cov(x, params, key)
1932
+ sample_mean, key, err1, params = jax_mean(fls, nfls, params, key)
1933
+ sample_cov, key, err2, params = jax_cov(fls, nfls, params, key)
1750
1934
 
1751
1935
  # sample Normal(0, 1)
1752
1936
  key, subkey = random.split(key)
1753
1937
  Z = random.normal(
1754
- key=subkey,
1755
- shape=jnp.shape(sample_mean) + (1,),
1756
- dtype=self.REAL
1757
- )
1938
+ key=subkey, shape=jnp.shape(sample_mean) + (1,), dtype=self.REAL)
1758
1939
 
1759
1940
  # compute L s.t. cov = L * L' and reparameterize
1760
1941
  L = jnp.linalg.cholesky(sample_cov)
1761
1942
  sample = jnp.matmul(L, Z)[..., 0] + sample_mean
1762
1943
  sample = jnp.moveaxis(sample, source=-1, destination=index)
1763
1944
  err = err1 | err2
1764
- return sample, key, err, params
1765
-
1945
+ return sample, key, err, params
1766
1946
  return _jax_wrapped_distribution_multivariate_normal
1767
1947
 
1768
- def _jax_multivariate_student(self, expr, init_params):
1948
+ def _jax_multivariate_student(self, expr, aux):
1769
1949
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_MULTIVARIATE_STUDENT']
1770
1950
 
1771
1951
  _, args = expr.args
1772
1952
  mean, cov, df = args
1773
- jax_mean = self._jax(mean, init_params)
1774
- jax_cov = self._jax(cov, init_params)
1775
- jax_df = self._jax(df, init_params)
1953
+ jax_mean = self._jax(mean, aux)
1954
+ jax_cov = self._jax(cov, aux)
1955
+ jax_df = self._jax(df, aux)
1776
1956
  index, = self.traced.cached_sim_info(expr)
1777
1957
 
1778
1958
  # reparameterization trick MN(m, LL') = LZ + m, where Z ~ StudentT(0, 1)
1779
- def _jax_wrapped_distribution_multivariate_student(x, params, key):
1959
+ def _jax_wrapped_distribution_multivariate_student(fls, nfls, params, key):
1780
1960
 
1781
1961
  # sample the mean and covariance and degrees of freedom
1782
- sample_mean, key, err1, params = jax_mean(x, params, key)
1783
- sample_cov, key, err2, params = jax_cov(x, params, key)
1784
- sample_df, key, err3, params = jax_df(x, params, key)
1962
+ sample_mean, key, err1, params = jax_mean(fls, nfls, params, key)
1963
+ sample_cov, key, err2, params = jax_cov(fls, nfls, params, key)
1964
+ sample_df, key, err3, params = jax_df(fls, nfls, params, key)
1785
1965
  out_of_bounds = jnp.logical_not(jnp.all(sample_df > 0))
1786
1966
 
1787
1967
  # sample StudentT(0, 1, df) -- broadcast df to same shape as cov
@@ -1800,43 +1980,41 @@ class JaxRDDLCompiler:
1800
1980
  sample = jnp.matmul(L, Z)[..., 0] + sample_mean
1801
1981
  sample = jnp.moveaxis(sample, source=-1, destination=index)
1802
1982
  error = err1 | err2 | err3 | (out_of_bounds * ERR)
1803
- return sample, key, error, params
1804
-
1983
+ return sample, key, error, params
1805
1984
  return _jax_wrapped_distribution_multivariate_student
1806
1985
 
1807
- def _jax_dirichlet(self, expr, init_params):
1986
+ def _jax_dirichlet(self, expr, aux):
1808
1987
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_DIRICHLET']
1809
1988
 
1810
1989
  _, args = expr.args
1811
1990
  alpha, = args
1812
- jax_alpha = self._jax(alpha, init_params)
1991
+ jax_alpha = self._jax(alpha, aux)
1813
1992
  index, = self.traced.cached_sim_info(expr)
1814
1993
 
1815
1994
  # sample Gamma(alpha_i, 1) and normalize across i
1816
- def _jax_wrapped_distribution_dirichlet(x, params, key):
1817
- alpha, key, error, params = jax_alpha(x, params, key)
1995
+ def _jax_wrapped_distribution_dirichlet(fls, nfls, params, key):
1996
+ alpha, key, error, params = jax_alpha(fls, nfls, params, key)
1818
1997
  out_of_bounds = jnp.logical_not(jnp.all(alpha > 0))
1819
- error |= (out_of_bounds * ERR)
1998
+ error = error | (out_of_bounds * ERR)
1820
1999
  key, subkey = random.split(key)
1821
- Gamma = random.gamma(key=subkey, a=alpha, dtype=self.REAL)
1822
- sample = Gamma / jnp.sum(Gamma, axis=-1, keepdims=True)
2000
+ gamma = random.gamma(key=subkey, a=alpha, dtype=self.REAL)
2001
+ sample = gamma / jnp.sum(gamma, axis=-1, keepdims=True)
1823
2002
  sample = jnp.moveaxis(sample, source=-1, destination=index)
1824
- return sample, key, error, params
1825
-
2003
+ return sample, key, error, params
1826
2004
  return _jax_wrapped_distribution_dirichlet
1827
2005
 
1828
- def _jax_multinomial(self, expr, init_params):
2006
+ def _jax_multinomial(self, expr, aux):
1829
2007
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_MULTINOMIAL']
1830
2008
 
1831
2009
  _, args = expr.args
1832
2010
  trials, prob = args
1833
- jax_trials = self._jax(trials, init_params)
1834
- jax_prob = self._jax(prob, init_params)
2011
+ jax_trials = self._jax(trials, aux)
2012
+ jax_prob = self._jax(prob, aux)
1835
2013
  index, = self.traced.cached_sim_info(expr)
1836
2014
 
1837
- def _jax_wrapped_distribution_multinomial(x, params, key):
1838
- trials, key, err1, params = jax_trials(x, params, key)
1839
- prob, key, err2, params = jax_prob(x, params, key)
2015
+ def _jax_wrapped_distribution_multinomial(fls, nfls, params, key):
2016
+ trials, key, err1, params = jax_trials(fls, nfls, params, key)
2017
+ prob, key, err2, params = jax_prob(fls, nfls, params, key)
1840
2018
  trials = jnp.asarray(trials, dtype=self.REAL)
1841
2019
  prob = jnp.asarray(prob, dtype=self.REAL)
1842
2020
  key, subkey = random.split(key)
@@ -1844,70 +2022,66 @@ class JaxRDDLCompiler:
1844
2022
  sample = jnp.asarray(dist.sample(seed=subkey), dtype=self.INT)
1845
2023
  sample = jnp.moveaxis(sample, source=-1, destination=index)
1846
2024
  out_of_bounds = jnp.logical_not(jnp.all(
1847
- (prob >= 0)
1848
- & jnp.allclose(jnp.sum(prob, axis=-1), 1.0)
1849
- & (trials >= 0)
2025
+ jnp.logical_and(
2026
+ jnp.logical_and(prob >= 0, jnp.allclose(jnp.sum(prob, axis=-1), 1.)),
2027
+ trials >= 0
2028
+ )
1850
2029
  ))
1851
2030
  error = err1 | err2 | (out_of_bounds * ERR)
1852
- return sample, key, error, params
1853
-
2031
+ return sample, key, error, params
1854
2032
  return _jax_wrapped_distribution_multinomial
1855
2033
 
1856
2034
  # ===========================================================================
1857
2035
  # matrix algebra
1858
2036
  # ===========================================================================
1859
2037
 
1860
- def _jax_matrix(self, expr, init_params):
2038
+ def _jax_matrix(self, expr, aux):
1861
2039
  _, op = expr.etype
1862
2040
  if op == 'det':
1863
- return self._jax_matrix_det(expr, init_params)
2041
+ return self._jax_matrix_det(expr, aux)
1864
2042
  elif op == 'inverse':
1865
- return self._jax_matrix_inv(expr, init_params, pseudo=False)
2043
+ return self._jax_matrix_inv(expr, aux, pseudo=False)
1866
2044
  elif op == 'pinverse':
1867
- return self._jax_matrix_inv(expr, init_params, pseudo=True)
2045
+ return self._jax_matrix_inv(expr, aux, pseudo=True)
1868
2046
  elif op == 'cholesky':
1869
- return self._jax_matrix_cholesky(expr, init_params)
2047
+ return self._jax_matrix_cholesky(expr, aux)
1870
2048
  else:
1871
2049
  raise RDDLNotImplementedError(
1872
- f'Matrix operation {op} is not supported.\n' +
1873
- print_stack_trace(expr))
2050
+ f'Matrix operation {op} is not supported.\n' + print_stack_trace(expr))
1874
2051
 
1875
- def _jax_matrix_det(self, expr, init_params):
1876
- * _, arg = expr.args
1877
- jax_arg = self._jax(arg, init_params)
2052
+ def _jax_matrix_det(self, expr, aux):
2053
+ arg = expr.args[-1]
2054
+ jax_arg = self._jax(arg, aux)
1878
2055
 
1879
- def _jax_wrapped_matrix_operation_det(x, params, key):
1880
- sample_arg, key, error, params = jax_arg(x, params, key)
2056
+ def _jax_wrapped_matrix_operation_det(fls, nfls, params, key):
2057
+ sample_arg, key, error, params = jax_arg(fls, nfls, params, key)
1881
2058
  sample = jnp.linalg.det(sample_arg)
1882
- return sample, key, error, params
1883
-
2059
+ return sample, key, error, params
1884
2060
  return _jax_wrapped_matrix_operation_det
1885
2061
 
1886
- def _jax_matrix_inv(self, expr, init_params, pseudo):
2062
+ def _jax_matrix_inv(self, expr, aux, pseudo):
1887
2063
  _, arg = expr.args
1888
- jax_arg = self._jax(arg, init_params)
2064
+ jax_arg = self._jax(arg, aux)
1889
2065
  indices = self.traced.cached_sim_info(expr)
1890
2066
  op = jnp.linalg.pinv if pseudo else jnp.linalg.inv
1891
2067
 
1892
- def _jax_wrapped_matrix_operation_inv(x, params, key):
1893
- sample_arg, key, error, params = jax_arg(x, params, key)
2068
+ def _jax_wrapped_matrix_operation_inv(fls, nfls, params, key):
2069
+ sample_arg, key, error, params = jax_arg(fls, nfls, params, key)
1894
2070
  sample = op(sample_arg)
1895
2071
  sample = jnp.moveaxis(sample, source=(-2, -1), destination=indices)
1896
- return sample, key, error, params
1897
-
2072
+ return sample, key, error, params
1898
2073
  return _jax_wrapped_matrix_operation_inv
1899
2074
 
1900
- def _jax_matrix_cholesky(self, expr, init_params):
2075
+ def _jax_matrix_cholesky(self, expr, aux):
1901
2076
  _, arg = expr.args
1902
- jax_arg = self._jax(arg, init_params)
2077
+ jax_arg = self._jax(arg, aux)
1903
2078
  indices = self.traced.cached_sim_info(expr)
1904
2079
  op = jnp.linalg.cholesky
1905
2080
 
1906
- def _jax_wrapped_matrix_operation_cholesky(x, params, key):
1907
- sample_arg, key, error, params = jax_arg(x, params, key)
2081
+ def _jax_wrapped_matrix_operation_cholesky(fls, nfls, params, key):
2082
+ sample_arg, key, error, params = jax_arg(fls, nfls, params, key)
1908
2083
  sample = op(sample_arg)
1909
2084
  sample = jnp.moveaxis(sample, source=(-2, -1), destination=indices)
1910
- return sample, key, error, params
1911
-
2085
+ return sample, key, error, params
1912
2086
  return _jax_wrapped_matrix_operation_cholesky
1913
-
2087
+