pyRDDLGym-jax 2.8__py3-none-any.whl → 3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +1080 -906
- pyRDDLGym_jax/core/logic.py +1537 -1369
- pyRDDLGym_jax/core/model.py +75 -86
- pyRDDLGym_jax/core/planner.py +883 -935
- pyRDDLGym_jax/core/simulator.py +20 -17
- pyRDDLGym_jax/core/tuning.py +11 -7
- pyRDDLGym_jax/core/visualization.py +115 -78
- pyRDDLGym_jax/entry_point.py +2 -1
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +6 -8
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +7 -8
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +7 -8
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +8 -9
- pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +7 -8
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +6 -7
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +6 -7
- pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +6 -8
- pyRDDLGym_jax/examples/configs/Quadcopter_physics_drp.cfg +17 -0
- pyRDDLGym_jax/examples/configs/Quadcopter_physics_slp.cfg +17 -0
- pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +4 -7
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +4 -7
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +6 -7
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +6 -7
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +6 -7
- pyRDDLGym_jax/examples/configs/default_drp.cfg +5 -8
- pyRDDLGym_jax/examples/configs/default_replan.cfg +5 -8
- pyRDDLGym_jax/examples/configs/default_slp.cfg +5 -8
- pyRDDLGym_jax/examples/configs/tuning_drp.cfg +6 -8
- pyRDDLGym_jax/examples/configs/tuning_replan.cfg +6 -8
- pyRDDLGym_jax/examples/configs/tuning_slp.cfg +6 -8
- pyRDDLGym_jax/examples/run_plan.py +2 -2
- pyRDDLGym_jax/examples/run_tune.py +2 -2
- {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/METADATA +22 -23
- pyrddlgym_jax-3.0.dist-info/RECORD +51 -0
- {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax/examples/run_gradient.py +0 -102
- pyrddlgym_jax-2.8.dist-info/RECORD +0 -50
- {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/entry_points.txt +0 -0
- {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/licenses/LICENSE +0 -0
- {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/compiler.py
CHANGED
|
@@ -14,12 +14,14 @@
|
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
from functools import partial
|
|
17
|
+
import termcolor
|
|
17
18
|
import traceback
|
|
18
19
|
from typing import Any, Callable, Dict, List, Optional
|
|
19
20
|
|
|
20
21
|
import jax
|
|
21
22
|
import jax.numpy as jnp
|
|
22
23
|
import jax.random as random
|
|
24
|
+
import jax.scipy as scipy
|
|
23
25
|
|
|
24
26
|
from pyRDDLGym.core.compiler.initializer import RDDLValueInitializer
|
|
25
27
|
from pyRDDLGym.core.compiler.levels import RDDLLevelAnalysis
|
|
@@ -36,8 +38,6 @@ from pyRDDLGym.core.debug.exception import (
|
|
|
36
38
|
from pyRDDLGym.core.debug.logger import Logger
|
|
37
39
|
from pyRDDLGym.core.simulator import RDDLSimulatorPrecompiled
|
|
38
40
|
|
|
39
|
-
from pyRDDLGym_jax.core.logic import ExactLogic
|
|
40
|
-
|
|
41
41
|
# more robust approach - if user does not have this or broken try to continue
|
|
42
42
|
try:
|
|
43
43
|
from tensorflow_probability.substrates import jax as tfp
|
|
@@ -53,12 +53,11 @@ class JaxRDDLCompiler:
|
|
|
53
53
|
All operations are identical to their numpy equivalents.
|
|
54
54
|
'''
|
|
55
55
|
|
|
56
|
-
def __init__(self, rddl: RDDLLiftedModel,
|
|
56
|
+
def __init__(self, rddl: RDDLLiftedModel, *args,
|
|
57
57
|
allow_synchronous_state: bool=True,
|
|
58
58
|
logger: Optional[Logger]=None,
|
|
59
59
|
use64bit: bool=False,
|
|
60
|
-
|
|
61
|
-
python_functions: Optional[Dict[str, Callable]]=None) -> None:
|
|
60
|
+
python_functions: Optional[Dict[str, Callable]]=None, **kwargs) -> None:
|
|
62
61
|
'''Creates a new RDDL to Jax compiler.
|
|
63
62
|
|
|
64
63
|
:param rddl: the RDDL model to compile into Jax
|
|
@@ -66,10 +65,17 @@ class JaxRDDLCompiler:
|
|
|
66
65
|
on each other
|
|
67
66
|
:param logger: to log information about compilation to file
|
|
68
67
|
:param use64bit: whether to use 64 bit arithmetic
|
|
69
|
-
:param compile_non_fluent_exact: whether non-fluent expressions
|
|
70
|
-
are always compiled using exact JAX expressions
|
|
71
68
|
:param python_functions: dictionary of external Python functions to call from RDDL
|
|
72
69
|
'''
|
|
70
|
+
|
|
71
|
+
# warn about unused parameters
|
|
72
|
+
if args:
|
|
73
|
+
print(termcolor.colored(
|
|
74
|
+
f'[WARN] JaxRDDLCompiler received invalid args {args}.', 'yellow'))
|
|
75
|
+
if kwargs:
|
|
76
|
+
print(termcolor.colored(
|
|
77
|
+
f'[WARN] JaxRDDLCompiler received invalid kwargs {kwargs}.', 'yellow'))
|
|
78
|
+
|
|
73
79
|
self.rddl = rddl
|
|
74
80
|
self.logger = logger
|
|
75
81
|
# jax.config.update('jax_log_compiles', True) # for testing ONLY
|
|
@@ -86,7 +92,7 @@ class JaxRDDLCompiler:
|
|
|
86
92
|
self.JAX_TYPES = {
|
|
87
93
|
'int': self.INT,
|
|
88
94
|
'real': self.REAL,
|
|
89
|
-
'bool':
|
|
95
|
+
'bool': jnp.bool_
|
|
90
96
|
}
|
|
91
97
|
|
|
92
98
|
# compile initial values
|
|
@@ -94,6 +100,7 @@ class JaxRDDLCompiler:
|
|
|
94
100
|
self.init_values = initializer.initialize()
|
|
95
101
|
|
|
96
102
|
# compute dependency graph for CPFs and sort them by evaluation order
|
|
103
|
+
self.allow_synchronous_state = allow_synchronous_state
|
|
97
104
|
sorter = RDDLLevelAnalysis(rddl, allow_synchronous_state=allow_synchronous_state)
|
|
98
105
|
self.levels = sorter.compute_levels()
|
|
99
106
|
|
|
@@ -101,10 +108,12 @@ class JaxRDDLCompiler:
|
|
|
101
108
|
tracer = RDDLObjectsTracer(rddl, cpf_levels=self.levels)
|
|
102
109
|
self.traced = tracer.trace()
|
|
103
110
|
|
|
104
|
-
#
|
|
111
|
+
# external python functions
|
|
105
112
|
if python_functions is None:
|
|
106
113
|
python_functions = {}
|
|
107
114
|
self.python_functions = python_functions
|
|
115
|
+
|
|
116
|
+
# extract the box constraints on actions
|
|
108
117
|
simulator = RDDLSimulatorPrecompiled(
|
|
109
118
|
rddl=self.rddl,
|
|
110
119
|
init_values=self.init_values,
|
|
@@ -112,75 +121,90 @@ class JaxRDDLCompiler:
|
|
|
112
121
|
trace_info=self.traced,
|
|
113
122
|
python_functions=python_functions
|
|
114
123
|
)
|
|
115
|
-
constraints = RDDLConstraints(simulator, vectorized=True)
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
124
|
+
self.constraints = RDDLConstraints(simulator, vectorized=True)
|
|
125
|
+
|
|
126
|
+
def get_kwargs(self) -> Dict[str, Any]:
|
|
127
|
+
'''Returns a dictionary of configurable parameter name: parameter value pairs.
|
|
128
|
+
'''
|
|
129
|
+
return {
|
|
130
|
+
'allow_synchronous_state': self.allow_synchronous_state,
|
|
131
|
+
'use64bit': self.use64bit,
|
|
132
|
+
'python_functions': self.python_functions
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
def split_fluent_nonfluent(self, values):
|
|
136
|
+
'''Splits given values dictionary into fluent and non-fluent dictionaries.
|
|
137
|
+
'''
|
|
138
|
+
nonfluents = self.rddl.non_fluents
|
|
139
|
+
fls = {name: value for (name, value) in values.items() if name not in nonfluents}
|
|
140
|
+
nfls = {name: value for (name, value) in values.items() if name in nonfluents}
|
|
141
|
+
return fls, nfls
|
|
123
142
|
|
|
124
143
|
# ===========================================================================
|
|
125
144
|
# main compilation subroutines
|
|
126
145
|
# ===========================================================================
|
|
127
146
|
|
|
128
|
-
def compile(self, log_jax_expr: bool=False,
|
|
147
|
+
def compile(self, log_jax_expr: bool=False,
|
|
148
|
+
heading: str='',
|
|
149
|
+
extra_aux: Dict[str, Any]={}) -> None:
|
|
129
150
|
'''Compiles the current RDDL into Jax expressions.
|
|
130
151
|
|
|
131
152
|
:param log_jax_expr: whether to pretty-print the compiled Jax functions
|
|
132
153
|
to the log file
|
|
133
154
|
:param heading: the heading to print before compilation information
|
|
155
|
+
:param extra_aux: extra info to save during compilations
|
|
134
156
|
'''
|
|
135
|
-
|
|
136
|
-
self.
|
|
137
|
-
|
|
138
|
-
self.
|
|
139
|
-
self.
|
|
140
|
-
self.
|
|
141
|
-
self.
|
|
142
|
-
|
|
157
|
+
self.model_aux = {'params': {}, 'overriden': {}}
|
|
158
|
+
self.model_aux.update(extra_aux)
|
|
159
|
+
|
|
160
|
+
self.invariants = self._compile_constraints(self.rddl.invariants, self.model_aux)
|
|
161
|
+
self.preconditions = self._compile_constraints(self.rddl.preconditions, self.model_aux)
|
|
162
|
+
self.terminations = self._compile_constraints(self.rddl.terminations, self.model_aux)
|
|
163
|
+
self.cpfs = self._compile_cpfs(self.model_aux)
|
|
164
|
+
self.reward = self._compile_reward(self.model_aux)
|
|
165
|
+
|
|
166
|
+
# add compiled jax expression to logger
|
|
143
167
|
if log_jax_expr and self.logger is not None:
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
168
|
+
self._log_printed_jax(heading)
|
|
169
|
+
|
|
170
|
+
def _log_printed_jax(self, heading=''):
|
|
171
|
+
printed = self.print_jax()
|
|
172
|
+
printed_cpfs = '\n\n'.join(f'{k}: {v}' for (k, v) in printed['cpfs'].items())
|
|
173
|
+
printed_reward = printed['reward']
|
|
174
|
+
printed_invariants = '\n\n'.join(v for v in printed['invariants'])
|
|
175
|
+
printed_preconds = '\n\n'.join(v for v in printed['preconditions'])
|
|
176
|
+
printed_terminals = '\n\n'.join(v for v in printed['terminations'])
|
|
177
|
+
printed_params = '\n'.join(f'{k}: {v}' for (k, v) in self.model_aux['params'].items())
|
|
178
|
+
self.logger.log(
|
|
179
|
+
f'[info] {heading}\n'
|
|
180
|
+
f'[info] compiled JAX CPFs:\n\n'
|
|
181
|
+
f'{printed_cpfs}\n\n'
|
|
182
|
+
f'[info] compiled JAX reward:\n\n'
|
|
183
|
+
f'{printed_reward}\n\n'
|
|
184
|
+
f'[info] compiled JAX invariants:\n\n'
|
|
185
|
+
f'{printed_invariants}\n\n'
|
|
186
|
+
f'[info] compiled JAX preconditions:\n\n'
|
|
187
|
+
f'{printed_preconds}\n\n'
|
|
188
|
+
f'[info] compiled JAX terminations:\n\n'
|
|
189
|
+
f'{printed_terminals}\n'
|
|
190
|
+
f'[info] model parameters:\n'
|
|
191
|
+
f'{printed_params}\n'
|
|
192
|
+
)
|
|
193
|
+
|
|
194
|
+
def _compile_constraints(self, constraints, aux):
|
|
195
|
+
return [self._jax(expr, aux, dtype=jnp.bool_) for expr in constraints]
|
|
196
|
+
|
|
197
|
+
def _compile_cpfs(self, aux):
|
|
173
198
|
jax_cpfs = {}
|
|
174
199
|
for cpfs in self.levels.values():
|
|
175
200
|
for cpf in cpfs:
|
|
176
201
|
_, expr = self.rddl.cpfs[cpf]
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
jax_cpfs[cpf] = self._jax(expr, init_params, dtype=dtype)
|
|
202
|
+
dtype = self.JAX_TYPES.get(self.rddl.variable_ranges[cpf], self.INT)
|
|
203
|
+
jax_cpfs[cpf] = self._jax(expr, aux, dtype=dtype)
|
|
180
204
|
return jax_cpfs
|
|
181
205
|
|
|
182
|
-
def _compile_reward(self,
|
|
183
|
-
return self._jax(self.rddl.reward,
|
|
206
|
+
def _compile_reward(self, aux):
|
|
207
|
+
return self._jax(self.rddl.reward, aux, dtype=self.REAL)
|
|
184
208
|
|
|
185
209
|
def _extract_inequality_constraint(self, expr):
|
|
186
210
|
result = []
|
|
@@ -208,55 +232,110 @@ class JaxRDDLCompiler:
|
|
|
208
232
|
result.extend(self._extract_equality_constraint(arg))
|
|
209
233
|
return result
|
|
210
234
|
|
|
211
|
-
def _jax_nonlinear_constraints(self,
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
for (i, expr) in enumerate(rddl.preconditions)
|
|
232
|
-
for constr in self._extract_equality_constraint(expr)
|
|
233
|
-
if not self.constraints.is_box_preconditions[i]]
|
|
234
|
-
|
|
235
|
-
# compile them to JAX and write as g(s, a) == 0
|
|
236
|
-
jax_equalities = []
|
|
237
|
-
for (left, right) in equalities:
|
|
238
|
-
jax_lhs = self._jax(left, init_params)
|
|
239
|
-
jax_rhs = self._jax(right, init_params)
|
|
240
|
-
jax_constr = self._jax_binary(jax_lhs, jax_rhs, jax_op, at_least_int=True)
|
|
241
|
-
jax_equalities.append(jax_constr)
|
|
242
|
-
|
|
235
|
+
def _jax_nonlinear_constraints(self, aux):
|
|
236
|
+
jax_equalities, jax_inequalities = [], []
|
|
237
|
+
for (i, expr) in enumerate(self.rddl.preconditions):
|
|
238
|
+
if not self.constraints.is_box_preconditions[i]:
|
|
239
|
+
|
|
240
|
+
# compile inequalities to JAX and write as h(s, a) <= 0
|
|
241
|
+
for (left, right) in self._extract_inequality_constraint(expr):
|
|
242
|
+
jax_lhs = self._jax(left, aux)
|
|
243
|
+
jax_rhs = self._jax(right, aux)
|
|
244
|
+
jax_constr = self._jax_binary(
|
|
245
|
+
jax_lhs, jax_rhs, jnp.subtract, at_least_int=True)
|
|
246
|
+
jax_inequalities.append(jax_constr)
|
|
247
|
+
|
|
248
|
+
# compile equalities to JAX and write as g(s, a) == 0
|
|
249
|
+
for (left, right) in self._extract_equality_constraint(expr):
|
|
250
|
+
jax_lhs = self._jax(left, aux)
|
|
251
|
+
jax_rhs = self._jax(right, aux)
|
|
252
|
+
jax_constr = self._jax_binary(
|
|
253
|
+
jax_lhs, jax_rhs, jnp.subtract, at_least_int=True)
|
|
254
|
+
jax_equalities.append(jax_constr)
|
|
243
255
|
return jax_inequalities, jax_equalities
|
|
244
256
|
|
|
257
|
+
def _jax_preconditions(self):
|
|
258
|
+
preconds = self.preconditions
|
|
259
|
+
def _jax_wrapped_preconditions(key, errors, fls, nfls, params):
|
|
260
|
+
precond_check = jnp.array(True, dtype=jnp.bool_)
|
|
261
|
+
for precond in preconds:
|
|
262
|
+
sample, key, err, params = precond(fls, nfls, params, key)
|
|
263
|
+
precond_check = jnp.logical_and(precond_check, sample)
|
|
264
|
+
errors = errors | err
|
|
265
|
+
return precond_check, key, errors, params
|
|
266
|
+
return _jax_wrapped_preconditions
|
|
267
|
+
|
|
268
|
+
def _jax_inequalities(self, aux_constr):
|
|
269
|
+
inequality_fns, equality_fns = self._jax_nonlinear_constraints(aux_constr)
|
|
270
|
+
def _jax_wrapped_inequalities(key, errors, fls, nfls, params):
|
|
271
|
+
inequalities, equalities = [], []
|
|
272
|
+
for constraint in inequality_fns:
|
|
273
|
+
sample, key, err, params = constraint(fls, nfls, params, key)
|
|
274
|
+
inequalities.append(sample)
|
|
275
|
+
errors = errors | err
|
|
276
|
+
for constraint in equality_fns:
|
|
277
|
+
sample, key, err, params = constraint(fls, nfls, params, key)
|
|
278
|
+
equalities.append(sample)
|
|
279
|
+
errors = errors | err
|
|
280
|
+
return (inequalities, equalities), key, errors, params
|
|
281
|
+
return _jax_wrapped_inequalities
|
|
282
|
+
|
|
283
|
+
def _jax_cpfs(self):
|
|
284
|
+
cpfs = self.cpfs
|
|
285
|
+
def _jax_wrapped_cpfs(key, errors, fls, nfls, params):
|
|
286
|
+
fls = fls.copy()
|
|
287
|
+
for (name, cpf) in cpfs.items():
|
|
288
|
+
fls[name], key, err, params = cpf(fls, nfls, params, key)
|
|
289
|
+
errors = errors | err
|
|
290
|
+
return fls, key, errors, params
|
|
291
|
+
return _jax_wrapped_cpfs
|
|
292
|
+
|
|
293
|
+
def _jax_reward(self):
|
|
294
|
+
reward_fn = self.reward
|
|
295
|
+
def _jax_wrapped_reward(key, errors, fls, nfls, params):
|
|
296
|
+
reward, key, err, params = reward_fn(fls, nfls, params, key)
|
|
297
|
+
errors = errors | err
|
|
298
|
+
return reward, key, errors, params
|
|
299
|
+
return _jax_wrapped_reward
|
|
300
|
+
|
|
301
|
+
def _jax_invariants(self):
|
|
302
|
+
invariants = self.invariants
|
|
303
|
+
def _jax_wrapped_invariants(key, errors, fls, nfls, params):
|
|
304
|
+
invariant_check = jnp.array(True, dtype=jnp.bool_)
|
|
305
|
+
for invariant in invariants:
|
|
306
|
+
sample, key, err, params = invariant(fls, nfls, params, key)
|
|
307
|
+
invariant_check = jnp.logical_and(invariant_check, sample)
|
|
308
|
+
errors = errors | err
|
|
309
|
+
return invariant_check, key, errors, params
|
|
310
|
+
return _jax_wrapped_invariants
|
|
311
|
+
|
|
312
|
+
def _jax_terminations(self):
|
|
313
|
+
terminations = self.terminations
|
|
314
|
+
def _jax_wrapped_terminations(key, errors, fls, nfls, params):
|
|
315
|
+
terminated_check = jnp.array(False, dtype=jnp.bool_)
|
|
316
|
+
for terminal in terminations:
|
|
317
|
+
sample, key, err, params = terminal(fls, nfls, params, key)
|
|
318
|
+
terminated_check = jnp.logical_or(terminated_check, sample)
|
|
319
|
+
errors = errors | err
|
|
320
|
+
return terminated_check, key, errors, params
|
|
321
|
+
return _jax_wrapped_terminations
|
|
322
|
+
|
|
245
323
|
def compile_transition(self, check_constraints: bool=False,
|
|
246
324
|
constraint_func: bool=False,
|
|
247
|
-
|
|
248
|
-
|
|
325
|
+
cache_path_info: bool=False,
|
|
326
|
+
aux_constr: Dict[str, Any]={}) -> Callable:
|
|
249
327
|
'''Compiles the current RDDL model into a JAX transition function that
|
|
250
328
|
samples the next state.
|
|
251
329
|
|
|
252
330
|
The arguments of the returned function is:
|
|
253
331
|
- key is the PRNG key
|
|
254
332
|
- actions is the dict of action tensors
|
|
255
|
-
-
|
|
256
|
-
-
|
|
333
|
+
- fls is the dict of current fluent pvar tensors
|
|
334
|
+
- nfls is the dict of nonfluent pvar tensors
|
|
335
|
+
- params is a dict of parameters for the relaxed model.
|
|
257
336
|
|
|
258
337
|
The returned value of the function is:
|
|
259
|
-
-
|
|
338
|
+
- fls is the returned next epoch fluent values
|
|
260
339
|
- log includes all the auxiliary information about constraints
|
|
261
340
|
satisfied, errors, etc.
|
|
262
341
|
|
|
@@ -284,104 +363,118 @@ class JaxRDDLCompiler:
|
|
|
284
363
|
in addition to the usual outputs
|
|
285
364
|
:param cache_path_info: whether to save full path traces as part of the log
|
|
286
365
|
'''
|
|
287
|
-
NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
|
|
288
|
-
rddl = self.rddl
|
|
289
|
-
reward_fn, cpfs, preconds, invariants, terminals = \
|
|
290
|
-
self.reward, self.cpfs, self.preconditions, self.invariants, self.terminations
|
|
366
|
+
NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
|
|
291
367
|
|
|
292
|
-
# compile
|
|
368
|
+
# compile all components of the RDDL
|
|
369
|
+
cpf_fn = self._jax_cpfs()
|
|
370
|
+
reward_fn = self._jax_reward()
|
|
371
|
+
|
|
372
|
+
# compile optional constraints
|
|
373
|
+
precond_fn = invariant_fn = terminal_fn = None
|
|
374
|
+
if check_constraints:
|
|
375
|
+
precond_fn = self._jax_preconditions()
|
|
376
|
+
invariant_fn = self._jax_invariants()
|
|
377
|
+
terminal_fn = self._jax_terminations()
|
|
378
|
+
|
|
379
|
+
# compile optional inequalities
|
|
380
|
+
ineq_fn = None
|
|
293
381
|
if constraint_func:
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
else:
|
|
297
|
-
inequality_fns, equality_fns = None, None
|
|
298
|
-
|
|
382
|
+
ineq_fn = self._jax_inequalities(aux_constr)
|
|
383
|
+
|
|
299
384
|
# do a single step update from the RDDL model
|
|
300
|
-
def _jax_wrapped_single_step(key, actions,
|
|
385
|
+
def _jax_wrapped_single_step(key, actions, fls, nfls, params):
|
|
301
386
|
errors = NORMAL
|
|
302
|
-
|
|
387
|
+
|
|
388
|
+
fls = fls.copy()
|
|
389
|
+
fls.update(actions)
|
|
303
390
|
|
|
304
391
|
# check action preconditions
|
|
305
|
-
precond_check = True
|
|
306
392
|
if check_constraints:
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
errors |= err
|
|
393
|
+
precond, key, errors, params = precond_fn(key, errors, fls, nfls, params)
|
|
394
|
+
else:
|
|
395
|
+
precond = jnp.array(True, dtype=jnp.bool_)
|
|
311
396
|
|
|
312
397
|
# compute h(s, a) <= 0 and g(s, a) == 0 constraint functions
|
|
313
|
-
inequalities, equalities = [], []
|
|
314
398
|
if constraint_func:
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
for constraint in equality_fns:
|
|
320
|
-
sample, key, err, model_params = constraint(subs, model_params, key)
|
|
321
|
-
equalities.append(sample)
|
|
322
|
-
errors |= err
|
|
399
|
+
(inequalities, equalities), key, errors, params = ineq_fn(
|
|
400
|
+
key, errors, fls, nfls, params)
|
|
401
|
+
else:
|
|
402
|
+
inequalities, equalities = [], []
|
|
323
403
|
|
|
324
404
|
# calculate CPFs in topological order
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
errors |= err
|
|
328
|
-
|
|
329
|
-
# calculate the immediate reward
|
|
330
|
-
reward, key, err, model_params = reward_fn(subs, model_params, key)
|
|
331
|
-
errors |= err
|
|
405
|
+
fls, key, errors, params = cpf_fn(key, errors, fls, nfls, params)
|
|
406
|
+
fluents = fls if cache_path_info else {}
|
|
332
407
|
|
|
333
|
-
# calculate
|
|
334
|
-
|
|
335
|
-
fluents = {name: values for (name, values) in subs.items()
|
|
336
|
-
if name not in rddl.non_fluents}
|
|
337
|
-
else:
|
|
338
|
-
fluents = {}
|
|
408
|
+
# calculate the immediate reward
|
|
409
|
+
reward, key, errors, params = reward_fn(key, errors, fls, nfls, params)
|
|
339
410
|
|
|
340
411
|
# set the next state to the current state
|
|
341
|
-
for (state, next_state) in rddl.next_state.items():
|
|
342
|
-
|
|
412
|
+
for (state, next_state) in self.rddl.next_state.items():
|
|
413
|
+
fls[state] = fls[next_state]
|
|
343
414
|
|
|
344
|
-
# check the state invariants
|
|
345
|
-
invariant_check = True
|
|
415
|
+
# check the state invariants and termination
|
|
346
416
|
if check_constraints:
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
terminated_check = False
|
|
354
|
-
if check_constraints:
|
|
355
|
-
for terminal in terminals:
|
|
356
|
-
sample, key, err, model_params = terminal(subs, model_params, key)
|
|
357
|
-
terminated_check = jnp.logical_or(terminated_check, sample)
|
|
358
|
-
errors |= err
|
|
359
|
-
|
|
417
|
+
invariant, key, errors, params = invariant_fn(key, errors, fls, nfls, params)
|
|
418
|
+
terminated, key, errors, params = terminal_fn(key, errors, fls, nfls, params)
|
|
419
|
+
else:
|
|
420
|
+
invariant = jnp.array(True, dtype=jnp.bool_)
|
|
421
|
+
terminated = jnp.array(False, dtype=jnp.bool_)
|
|
422
|
+
|
|
360
423
|
# prepare the return value
|
|
361
424
|
log = {
|
|
362
425
|
'fluents': fluents,
|
|
363
426
|
'reward': reward,
|
|
364
427
|
'error': errors,
|
|
365
|
-
'precondition':
|
|
366
|
-
'invariant':
|
|
367
|
-
'termination':
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
return subs, log, model_params
|
|
428
|
+
'precondition': precond,
|
|
429
|
+
'invariant': invariant,
|
|
430
|
+
'termination': terminated,
|
|
431
|
+
'inequalities': inequalities,
|
|
432
|
+
'equalities': equalities
|
|
433
|
+
}
|
|
434
|
+
return fls, log, params
|
|
374
435
|
|
|
375
436
|
return _jax_wrapped_single_step
|
|
376
437
|
|
|
438
|
+
def _compile_policy_step(self, policy, transition_fn):
|
|
439
|
+
def _jax_wrapped_policy_step(key, policy_params, hyperparams, step, fls, nfls,
|
|
440
|
+
model_params):
|
|
441
|
+
key, subkey = random.split(key)
|
|
442
|
+
actions = policy(key, policy_params, hyperparams, step, fls)
|
|
443
|
+
return transition_fn(subkey, actions, fls, nfls, model_params)
|
|
444
|
+
return _jax_wrapped_policy_step
|
|
445
|
+
|
|
446
|
+
def _compile_batched_policy_step(self, policy_step_fn, n_batch, model_params_reduction):
|
|
447
|
+
def _jax_wrapped_batched_policy_step(carry, step):
|
|
448
|
+
key, policy_params, hyperparams, fls, nfls, model_params = carry
|
|
449
|
+
keys = random.split(key, num=1 + n_batch)
|
|
450
|
+
key, subkeys = keys[0], keys[1:]
|
|
451
|
+
fls, log, model_params = jax.vmap(
|
|
452
|
+
policy_step_fn, in_axes=(0, None, None, None, 0, None, None)
|
|
453
|
+
)(subkeys, policy_params, hyperparams, step, fls, nfls, model_params)
|
|
454
|
+
model_params = jax.tree_util.tree_map(model_params_reduction, model_params)
|
|
455
|
+
carry = (key, policy_params, hyperparams, fls, nfls, model_params)
|
|
456
|
+
return carry, log
|
|
457
|
+
return _jax_wrapped_batched_policy_step
|
|
458
|
+
|
|
459
|
+
def _compile_unrolled_policy_step(self, batched_policy_step_fn, n_steps):
|
|
460
|
+
def _jax_wrapped_batched_policy_rollout(key, policy_params, hyperparams, fls, nfls,
|
|
461
|
+
model_params):
|
|
462
|
+
start = (key, policy_params, hyperparams, fls, nfls, model_params)
|
|
463
|
+
steps = jnp.arange(n_steps)
|
|
464
|
+
end, log = jax.lax.scan(batched_policy_step_fn, start, steps)
|
|
465
|
+
log = jax.tree_util.tree_map(partial(jnp.swapaxes, axis1=0, axis2=1), log)
|
|
466
|
+
model_params = end[-1]
|
|
467
|
+
return log, model_params
|
|
468
|
+
return _jax_wrapped_batched_policy_rollout
|
|
469
|
+
|
|
377
470
|
def compile_rollouts(self, policy: Callable,
|
|
378
471
|
n_steps: int,
|
|
379
472
|
n_batch: int,
|
|
380
473
|
check_constraints: bool=False,
|
|
381
474
|
constraint_func: bool=False,
|
|
382
|
-
|
|
475
|
+
cache_path_info: bool=False,
|
|
383
476
|
model_params_reduction: Callable=lambda x: x[0],
|
|
384
|
-
|
|
477
|
+
aux_constr: Dict[str, Any]={}) -> Callable:
|
|
385
478
|
'''Compiles the current RDDL model into a JAX transition function that
|
|
386
479
|
samples trajectories with a fixed horizon from a policy.
|
|
387
480
|
|
|
@@ -389,7 +482,8 @@ class JaxRDDLCompiler:
|
|
|
389
482
|
- key is the PRNG key (used by a stochastic policy)
|
|
390
483
|
- policy_params is a pytree of trainable policy weights
|
|
391
484
|
- hyperparams is a pytree of (optional) fixed policy hyper-parameters
|
|
392
|
-
-
|
|
485
|
+
- fls is the dictionary of current fluent tensor values
|
|
486
|
+
- nfls is the dictionary of next step fluent tensor value
|
|
393
487
|
- model_params is a dict of model hyperparameters.
|
|
394
488
|
|
|
395
489
|
The returned value of the returned function is:
|
|
@@ -402,7 +496,7 @@ class JaxRDDLCompiler:
|
|
|
402
496
|
- params is a pytree of trainable policy weights
|
|
403
497
|
- hyperparams is a pytree of (optional) fixed policy hyper-parameters
|
|
404
498
|
- step is the time index of the decision in the current rollout
|
|
405
|
-
-
|
|
499
|
+
- fls is a dict of fluent tensors for the current epoch.
|
|
406
500
|
|
|
407
501
|
:param policy: a Jax compiled function for the policy as described above
|
|
408
502
|
decision epoch, state dict, and an RNG key and returns an action dict
|
|
@@ -413,54 +507,16 @@ class JaxRDDLCompiler:
|
|
|
413
507
|
returned log and does not raise an exception
|
|
414
508
|
:param constraint_func: produces the h(s, a) constraint function
|
|
415
509
|
in addition to the usual outputs
|
|
510
|
+
:param cache_path_info: whether to save full path traces as part of the log
|
|
416
511
|
:param model_params_reduction: how to aggregate updated model_params across runs
|
|
417
512
|
in the batch (defaults to selecting the first element's parameters in the batch)
|
|
418
|
-
:param cache_path_info: whether to save full path traces as part of the log
|
|
419
513
|
'''
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
observed_vars = rddl.observ_fluents
|
|
427
|
-
else:
|
|
428
|
-
observed_vars = rddl.state_fluents
|
|
429
|
-
|
|
430
|
-
# evaluate the step from the policy
|
|
431
|
-
def _jax_wrapped_single_step_policy(key, policy_params, hyperparams,
|
|
432
|
-
step, subs, model_params):
|
|
433
|
-
states = {var: values
|
|
434
|
-
for (var, values) in subs.items()
|
|
435
|
-
if var in observed_vars}
|
|
436
|
-
actions = policy(key, policy_params, hyperparams, step, states)
|
|
437
|
-
key, subkey = random.split(key)
|
|
438
|
-
return jax_step_fn(subkey, actions, subs, model_params)
|
|
439
|
-
|
|
440
|
-
# do a batched step update from the policy
|
|
441
|
-
def _jax_wrapped_batched_step_policy(carry, step):
|
|
442
|
-
key, policy_params, hyperparams, subs, model_params = carry
|
|
443
|
-
key, *subkeys = random.split(key, num=1 + n_batch)
|
|
444
|
-
keys = jnp.asarray(subkeys)
|
|
445
|
-
subs, log, model_params = jax.vmap(
|
|
446
|
-
_jax_wrapped_single_step_policy,
|
|
447
|
-
in_axes=(0, None, None, None, 0, None)
|
|
448
|
-
)(keys, policy_params, hyperparams, step, subs, model_params)
|
|
449
|
-
model_params = jax.tree_util.tree_map(model_params_reduction, model_params)
|
|
450
|
-
carry = (key, policy_params, hyperparams, subs, model_params)
|
|
451
|
-
return carry, log
|
|
452
|
-
|
|
453
|
-
# do a batched roll-out from the policy
|
|
454
|
-
def _jax_wrapped_batched_rollout(key, policy_params, hyperparams,
|
|
455
|
-
subs, model_params):
|
|
456
|
-
start = (key, policy_params, hyperparams, subs, model_params)
|
|
457
|
-
steps = jnp.arange(n_steps)
|
|
458
|
-
end, log = jax.lax.scan(_jax_wrapped_batched_step_policy, start, steps)
|
|
459
|
-
log = jax.tree_util.tree_map(partial(jnp.swapaxes, axis1=0, axis2=1), log)
|
|
460
|
-
model_params = end[-1]
|
|
461
|
-
return log, model_params
|
|
462
|
-
|
|
463
|
-
return _jax_wrapped_batched_rollout
|
|
514
|
+
jax_fn = self.compile_transition(
|
|
515
|
+
check_constraints, constraint_func, cache_path_info, aux_constr)
|
|
516
|
+
jax_fn = self._compile_policy_step(policy, jax_fn)
|
|
517
|
+
jax_fn = self._compile_batched_policy_step(jax_fn, n_batch, model_params_reduction)
|
|
518
|
+
jax_fn = self._compile_unrolled_policy_step(jax_fn, n_steps)
|
|
519
|
+
return jax_fn
|
|
464
520
|
|
|
465
521
|
# ===========================================================================
|
|
466
522
|
# error checks and prints
|
|
@@ -470,43 +526,59 @@ class JaxRDDLCompiler:
|
|
|
470
526
|
'''Returns a dictionary containing the string representations of all
|
|
471
527
|
Jax compiled expressions from the RDDL file.
|
|
472
528
|
'''
|
|
473
|
-
|
|
474
|
-
|
|
529
|
+
fls, nfls = self.split_fluent_nonfluent(self.init_values)
|
|
530
|
+
params = self.model_aux['params']
|
|
475
531
|
key = jax.random.PRNGKey(42)
|
|
476
532
|
printed = {
|
|
477
|
-
'cpfs': {
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
'
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
533
|
+
'cpfs': {
|
|
534
|
+
name: str(jax.make_jaxpr(expr)(fls, nfls, params, key))
|
|
535
|
+
for (name, expr) in self.cpfs.items()
|
|
536
|
+
},
|
|
537
|
+
'reward': str(jax.make_jaxpr(self.reward)(fls, nfls, params, key)),
|
|
538
|
+
'invariants': [
|
|
539
|
+
str(jax.make_jaxpr(expr)(fls, nfls, params, key))
|
|
540
|
+
for expr in self.invariants
|
|
541
|
+
],
|
|
542
|
+
'preconditions': [
|
|
543
|
+
str(jax.make_jaxpr(expr)(fls, nfls, params, key))
|
|
544
|
+
for expr in self.preconditions
|
|
545
|
+
],
|
|
546
|
+
'terminations': [
|
|
547
|
+
str(jax.make_jaxpr(expr)(fls, nfls, params, key))
|
|
548
|
+
for expr in self.terminations
|
|
549
|
+
]
|
|
486
550
|
}
|
|
487
551
|
return printed
|
|
488
552
|
|
|
489
553
|
def model_parameter_info(self) -> Dict[str, Dict[str, Any]]:
|
|
490
554
|
'''Returns a dictionary of additional information about model parameters.'''
|
|
491
555
|
result = {}
|
|
492
|
-
for (id, value) in self.
|
|
493
|
-
|
|
494
|
-
expr = self.traced.lookup(expr_id)
|
|
556
|
+
for (id, value) in self.model_aux['params'].items():
|
|
557
|
+
expr = self.traced.lookup(id)
|
|
495
558
|
result[id] = {
|
|
496
|
-
'id':
|
|
559
|
+
'id': id,
|
|
497
560
|
'rddl_op': ' '.join(expr.etype),
|
|
498
561
|
'init_value': value
|
|
499
562
|
}
|
|
500
563
|
return result
|
|
501
564
|
|
|
565
|
+
def overriden_ops_info(self) -> Dict[str, Dict[str, List[int]]]:
|
|
566
|
+
'''Returns a dictionary of operations overriden by another class.'''
|
|
567
|
+
result = {}
|
|
568
|
+
for (id, class_) in self.model_aux['overriden'].items():
|
|
569
|
+
expr = self.traced.lookup(id)
|
|
570
|
+
rddl_op = ' '.join(expr.etype)
|
|
571
|
+
result.setdefault(class_, {}).setdefault(rddl_op, []).append(id)
|
|
572
|
+
return result
|
|
573
|
+
|
|
502
574
|
@staticmethod
|
|
503
575
|
def _check_valid_op(expr, valid_ops):
|
|
504
576
|
etype, op = expr.etype
|
|
505
577
|
if op not in valid_ops:
|
|
506
|
-
valid_op_str = ','.join(valid_ops
|
|
578
|
+
valid_op_str = ','.join(valid_ops)
|
|
507
579
|
raise RDDLNotImplementedError(
|
|
508
580
|
f'{etype} operator {op} is not supported: '
|
|
509
|
-
f'must be
|
|
581
|
+
f'must be one of {valid_op_str}.\n' + print_stack_trace(expr))
|
|
510
582
|
|
|
511
583
|
@staticmethod
|
|
512
584
|
def _check_num_args(expr, required_args):
|
|
@@ -516,6 +588,15 @@ class JaxRDDLCompiler:
|
|
|
516
588
|
raise RDDLInvalidNumberOfArgumentsError(
|
|
517
589
|
f'{etype} operator {op} requires {required_args} arguments, '
|
|
518
590
|
f'got {actual_args}.\n' + print_stack_trace(expr))
|
|
591
|
+
|
|
592
|
+
@staticmethod
|
|
593
|
+
def _check_num_args_min(expr, required_args):
|
|
594
|
+
actual_args = len(expr.args)
|
|
595
|
+
if actual_args < required_args:
|
|
596
|
+
etype, op = expr.etype
|
|
597
|
+
raise RDDLInvalidNumberOfArgumentsError(
|
|
598
|
+
f'{etype} operator {op} requires at least {required_args} arguments, '
|
|
599
|
+
f'got {actual_args}.\n' + print_stack_trace(expr))
|
|
519
600
|
|
|
520
601
|
ERROR_CODES = {
|
|
521
602
|
'NORMAL': 0,
|
|
@@ -580,8 +661,7 @@ class JaxRDDLCompiler:
|
|
|
580
661
|
decomposes it into individual error codes.
|
|
581
662
|
'''
|
|
582
663
|
binary = reversed(bin(error)[2:])
|
|
583
|
-
|
|
584
|
-
return errors
|
|
664
|
+
return [i for (i, c) in enumerate(binary) if c == '1']
|
|
585
665
|
|
|
586
666
|
@staticmethod
|
|
587
667
|
def get_error_messages(error: int) -> List[str]:
|
|
@@ -589,63 +669,59 @@ class JaxRDDLCompiler:
|
|
|
589
669
|
decomposes it into error strings.
|
|
590
670
|
'''
|
|
591
671
|
codes = JaxRDDLCompiler.get_error_codes(error)
|
|
592
|
-
|
|
593
|
-
return messages
|
|
672
|
+
return [JaxRDDLCompiler.INVERSE_ERROR_CODES[i] for i in codes]
|
|
594
673
|
|
|
595
674
|
# ===========================================================================
|
|
596
675
|
# expression compilation
|
|
597
676
|
# ===========================================================================
|
|
598
677
|
|
|
599
|
-
def _jax(self, expr,
|
|
678
|
+
def _jax(self, expr, aux, dtype=None):
|
|
600
679
|
etype, _ = expr.etype
|
|
601
680
|
if etype == 'constant':
|
|
602
|
-
jax_expr = self._jax_constant(expr,
|
|
681
|
+
jax_expr = self._jax_constant(expr, aux)
|
|
603
682
|
elif etype == 'pvar':
|
|
604
|
-
jax_expr = self._jax_pvar(expr,
|
|
683
|
+
jax_expr = self._jax_pvar(expr, aux)
|
|
605
684
|
elif etype == 'arithmetic':
|
|
606
|
-
jax_expr = self._jax_arithmetic(expr,
|
|
685
|
+
jax_expr = self._jax_arithmetic(expr, aux)
|
|
607
686
|
elif etype == 'relational':
|
|
608
|
-
jax_expr = self._jax_relational(expr,
|
|
687
|
+
jax_expr = self._jax_relational(expr, aux)
|
|
609
688
|
elif etype == 'boolean':
|
|
610
|
-
jax_expr = self._jax_logical(expr,
|
|
689
|
+
jax_expr = self._jax_logical(expr, aux)
|
|
611
690
|
elif etype == 'aggregation':
|
|
612
|
-
jax_expr = self._jax_aggregation(expr,
|
|
691
|
+
jax_expr = self._jax_aggregation(expr, aux)
|
|
613
692
|
elif etype == 'func':
|
|
614
|
-
jax_expr = self.
|
|
693
|
+
jax_expr = self._jax_function(expr, aux)
|
|
615
694
|
elif etype == 'pyfunc':
|
|
616
|
-
jax_expr = self._jax_pyfunc(expr,
|
|
695
|
+
jax_expr = self._jax_pyfunc(expr, aux)
|
|
617
696
|
elif etype == 'control':
|
|
618
|
-
jax_expr = self._jax_control(expr,
|
|
697
|
+
jax_expr = self._jax_control(expr, aux)
|
|
619
698
|
elif etype == 'randomvar':
|
|
620
|
-
jax_expr = self._jax_random(expr,
|
|
699
|
+
jax_expr = self._jax_random(expr, aux)
|
|
621
700
|
elif etype == 'randomvector':
|
|
622
|
-
jax_expr = self._jax_random_vector(expr,
|
|
701
|
+
jax_expr = self._jax_random_vector(expr, aux)
|
|
623
702
|
elif etype == 'matrix':
|
|
624
|
-
jax_expr = self._jax_matrix(expr,
|
|
703
|
+
jax_expr = self._jax_matrix(expr, aux)
|
|
625
704
|
else:
|
|
626
705
|
raise RDDLNotImplementedError(
|
|
627
|
-
f'
|
|
628
|
-
print_stack_trace(expr))
|
|
706
|
+
f'Expression type {expr} is not supported.\n' + print_stack_trace(expr))
|
|
629
707
|
|
|
630
708
|
# force type cast of tensor as required by caller
|
|
631
709
|
if dtype is not None:
|
|
632
710
|
jax_expr = self._jax_cast(jax_expr, dtype)
|
|
633
|
-
|
|
634
711
|
return jax_expr
|
|
635
712
|
|
|
636
713
|
def _jax_cast(self, jax_expr, dtype):
|
|
637
714
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
|
|
638
|
-
|
|
639
|
-
def _jax_wrapped_cast(
|
|
640
|
-
val, key, err, params = jax_expr(
|
|
715
|
+
|
|
716
|
+
def _jax_wrapped_cast(fls, nfls, params, key):
|
|
717
|
+
val, key, err, params = jax_expr(fls, nfls, params, key)
|
|
641
718
|
sample = jnp.asarray(val, dtype=dtype)
|
|
642
719
|
invalid_cast = jnp.logical_and(
|
|
643
720
|
jnp.logical_not(jnp.can_cast(val, dtype)),
|
|
644
721
|
jnp.any(sample != val)
|
|
645
722
|
)
|
|
646
|
-
err
|
|
723
|
+
err = err | (invalid_cast * ERR)
|
|
647
724
|
return sample, key, err, params
|
|
648
|
-
|
|
649
725
|
return _jax_wrapped_cast
|
|
650
726
|
|
|
651
727
|
def _fix_dtype(self, value):
|
|
@@ -654,34 +730,33 @@ class JaxRDDLCompiler:
|
|
|
654
730
|
return self.INT
|
|
655
731
|
elif jnp.issubdtype(dtype, jnp.floating):
|
|
656
732
|
return self.REAL
|
|
657
|
-
elif jnp.issubdtype(dtype, jnp.bool_)
|
|
658
|
-
return
|
|
733
|
+
elif jnp.issubdtype(dtype, jnp.bool_):
|
|
734
|
+
return jnp.bool_
|
|
659
735
|
else:
|
|
660
|
-
raise TypeError(f'
|
|
736
|
+
raise TypeError(f'dtype {dtype} of {value} is not valid.')
|
|
661
737
|
|
|
662
738
|
# ===========================================================================
|
|
663
739
|
# leaves
|
|
664
740
|
# ===========================================================================
|
|
665
741
|
|
|
666
|
-
def _jax_constant(self, expr,
|
|
742
|
+
def _jax_constant(self, expr, aux):
|
|
667
743
|
NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
|
|
668
744
|
cached_value = self.traced.cached_sim_info(expr)
|
|
745
|
+
dtype = self._fix_dtype(cached_value)
|
|
669
746
|
|
|
670
|
-
def _jax_wrapped_constant(
|
|
671
|
-
sample = jnp.asarray(cached_value, dtype=
|
|
747
|
+
def _jax_wrapped_constant(fls, nfls, params, key):
|
|
748
|
+
sample = jnp.asarray(cached_value, dtype=dtype)
|
|
672
749
|
return sample, key, NORMAL, params
|
|
673
|
-
|
|
674
750
|
return _jax_wrapped_constant
|
|
675
751
|
|
|
676
|
-
def _jax_pvar_slice(self,
|
|
752
|
+
def _jax_pvar_slice(self, slice_):
|
|
677
753
|
NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
|
|
678
754
|
|
|
679
|
-
def _jax_wrapped_pvar_slice(
|
|
680
|
-
return
|
|
681
|
-
|
|
755
|
+
def _jax_wrapped_pvar_slice(fls, nfls, params, key):
|
|
756
|
+
return slice_, key, NORMAL, params
|
|
682
757
|
return _jax_wrapped_pvar_slice
|
|
683
758
|
|
|
684
|
-
def _jax_pvar(self, expr,
|
|
759
|
+
def _jax_pvar(self, expr, aux):
|
|
685
760
|
NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
|
|
686
761
|
var, pvars = expr.args
|
|
687
762
|
is_value, cached_info = self.traced.cached_sim_info(expr)
|
|
@@ -690,21 +765,19 @@ class JaxRDDLCompiler:
|
|
|
690
765
|
# boundary case: domain object is converted to canonical integer index
|
|
691
766
|
if is_value:
|
|
692
767
|
cached_value = cached_info
|
|
693
|
-
|
|
694
|
-
def _jax_wrapped_object(x, params, key):
|
|
695
|
-
sample = jnp.asarray(cached_value, dtype=self._fix_dtype(cached_value))
|
|
696
|
-
return sample, key, NORMAL, params
|
|
768
|
+
dtype = self._fix_dtype(cached_value)
|
|
697
769
|
|
|
770
|
+
def _jax_wrapped_object(fls, nfls, params, key):
|
|
771
|
+
sample = jnp.asarray(cached_value, dtype=dtype)
|
|
772
|
+
return sample, key, NORMAL, params
|
|
698
773
|
return _jax_wrapped_object
|
|
699
774
|
|
|
700
775
|
# boundary case: no shape information (e.g. scalar pvar)
|
|
701
776
|
elif cached_info is None:
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
value = x[var]
|
|
777
|
+
def _jax_wrapped_pvar_scalar(fls, nfls, params, key):
|
|
778
|
+
value = fls[var] if var in fls else nfls[var]
|
|
705
779
|
sample = jnp.asarray(value, dtype=self._fix_dtype(value))
|
|
706
780
|
return sample, key, NORMAL, params
|
|
707
|
-
|
|
708
781
|
return _jax_wrapped_pvar_scalar
|
|
709
782
|
|
|
710
783
|
# must slice and/or reshape value tensor to match free variables
|
|
@@ -713,34 +786,29 @@ class JaxRDDLCompiler:
|
|
|
713
786
|
|
|
714
787
|
# compile nested expressions
|
|
715
788
|
if slices and op_code == RDDLObjectsTracer.NUMPY_OP_CODE.NESTED_SLICE:
|
|
789
|
+
jax_nested_expr = [
|
|
790
|
+
(self._jax(arg, aux) if slice_ is None else self._jax_pvar_slice(slice_))
|
|
791
|
+
for (arg, slice_) in zip(pvars, slices)
|
|
792
|
+
]
|
|
716
793
|
|
|
717
|
-
|
|
718
|
-
if _slice is None
|
|
719
|
-
else self._jax_pvar_slice(_slice))
|
|
720
|
-
for (arg, _slice) in zip(pvars, slices)]
|
|
721
|
-
|
|
722
|
-
def _jax_wrapped_pvar_tensor_nested(x, params, key):
|
|
794
|
+
def _jax_wrapped_pvar_tensor_nested(fls, nfls, params, key):
|
|
723
795
|
error = NORMAL
|
|
724
|
-
value =
|
|
796
|
+
value = fls[var] if var in fls else nfls[var]
|
|
725
797
|
sample = jnp.asarray(value, dtype=self._fix_dtype(value))
|
|
726
|
-
new_slices = [
|
|
727
|
-
for
|
|
728
|
-
new_slice, key, err, params = jax_expr(
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
new_slices = tuple(new_slices)
|
|
734
|
-
sample = sample[new_slices]
|
|
798
|
+
new_slices = []
|
|
799
|
+
for jax_expr in jax_nested_expr:
|
|
800
|
+
new_slice, key, err, params = jax_expr(fls, nfls, params, key)
|
|
801
|
+
new_slice = jnp.asarray(new_slice, dtype=self.INT)
|
|
802
|
+
new_slices.append(new_slice)
|
|
803
|
+
error = error | err
|
|
804
|
+
sample = sample[tuple(new_slices)]
|
|
735
805
|
return sample, key, error, params
|
|
736
|
-
|
|
737
806
|
return _jax_wrapped_pvar_tensor_nested
|
|
738
807
|
|
|
739
808
|
# tensor variable but no nesting
|
|
740
809
|
else:
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
value = x[var]
|
|
810
|
+
def _jax_wrapped_pvar_tensor_non_nested(fls, nfls, params, key):
|
|
811
|
+
value = fls[var] if var in fls else nfls[var]
|
|
744
812
|
sample = jnp.asarray(value, dtype=self._fix_dtype(value))
|
|
745
813
|
if slices:
|
|
746
814
|
sample = sample[slices]
|
|
@@ -752,190 +820,408 @@ class JaxRDDLCompiler:
|
|
|
752
820
|
elif op_code == RDDLObjectsTracer.NUMPY_OP_CODE.TRANSPOSE:
|
|
753
821
|
sample = jnp.transpose(sample, axes=op_args)
|
|
754
822
|
return sample, key, NORMAL, params
|
|
755
|
-
|
|
756
823
|
return _jax_wrapped_pvar_tensor_non_nested
|
|
757
824
|
|
|
758
825
|
# ===========================================================================
|
|
759
|
-
#
|
|
826
|
+
# boilerplate helper functions
|
|
760
827
|
# ===========================================================================
|
|
761
828
|
|
|
762
829
|
def _jax_unary(self, jax_expr, jax_op, at_least_int=False, check_dtype=None):
|
|
763
830
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
|
|
764
831
|
|
|
765
|
-
def _jax_wrapped_unary_op(
|
|
766
|
-
sample, key, err, params = jax_expr(
|
|
832
|
+
def _jax_wrapped_unary_op(fls, nfls, params, key):
|
|
833
|
+
sample, key, err, params = jax_expr(fls, nfls, params, key)
|
|
767
834
|
if at_least_int:
|
|
768
835
|
sample = self.ONE * sample
|
|
769
|
-
sample, params = jax_op(sample, params)
|
|
770
836
|
if check_dtype is not None:
|
|
771
837
|
invalid_cast = jnp.logical_not(jnp.can_cast(sample, check_dtype))
|
|
772
|
-
err
|
|
838
|
+
err = err | (invalid_cast * ERR)
|
|
839
|
+
sample = jax_op(sample)
|
|
773
840
|
return sample, key, err, params
|
|
774
|
-
|
|
775
841
|
return _jax_wrapped_unary_op
|
|
776
842
|
|
|
777
843
|
def _jax_binary(self, jax_lhs, jax_rhs, jax_op, at_least_int=False, check_dtype=None):
|
|
778
844
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
|
|
779
845
|
|
|
780
|
-
def _jax_wrapped_binary_op(
|
|
781
|
-
sample1, key, err1, params = jax_lhs(
|
|
782
|
-
sample2, key, err2, params = jax_rhs(
|
|
846
|
+
def _jax_wrapped_binary_op(fls, nfls, params, key):
|
|
847
|
+
sample1, key, err1, params = jax_lhs(fls, nfls, params, key)
|
|
848
|
+
sample2, key, err2, params = jax_rhs(fls, nfls, params, key)
|
|
783
849
|
if at_least_int:
|
|
784
850
|
sample1 = self.ONE * sample1
|
|
785
851
|
sample2 = self.ONE * sample2
|
|
786
|
-
sample
|
|
852
|
+
sample = jax_op(sample1, sample2)
|
|
787
853
|
err = err1 | err2
|
|
788
854
|
if check_dtype is not None:
|
|
789
855
|
invalid_cast = jnp.logical_not(jnp.logical_and(
|
|
790
856
|
jnp.can_cast(sample1, check_dtype),
|
|
791
857
|
jnp.can_cast(sample2, check_dtype))
|
|
792
858
|
)
|
|
793
|
-
err
|
|
859
|
+
err = err | (invalid_cast * ERR)
|
|
794
860
|
return sample, key, err, params
|
|
795
|
-
|
|
796
861
|
return _jax_wrapped_binary_op
|
|
797
862
|
|
|
798
|
-
def
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
valid_ops = self.OPS['arithmetic']
|
|
807
|
-
negative_op = self.OPS['negative']
|
|
808
|
-
JaxRDDLCompiler._check_valid_op(expr, valid_ops)
|
|
809
|
-
|
|
810
|
-
# recursively compile arguments
|
|
811
|
-
args = expr.args
|
|
812
|
-
n = len(args)
|
|
813
|
-
if n == 1 and op == '-':
|
|
814
|
-
arg, = args
|
|
815
|
-
jax_expr = self._jax(arg, init_params)
|
|
816
|
-
jax_op = negative_op(expr.id, init_params)
|
|
817
|
-
return self._jax_unary(jax_expr, jax_op, at_least_int=True)
|
|
818
|
-
|
|
819
|
-
elif n == 2 or (n >= 2 and op in {'*', '+'}):
|
|
820
|
-
jax_exprs = [self._jax(arg, init_params) for arg in args]
|
|
821
|
-
result = jax_exprs[0]
|
|
822
|
-
for (i, jax_rhs) in enumerate(jax_exprs[1:]):
|
|
823
|
-
jax_op = valid_ops[op](f'{expr.id}_{op}{i}', init_params)
|
|
824
|
-
result = self._jax_binary(result, jax_rhs, jax_op, at_least_int=True)
|
|
825
|
-
return result
|
|
826
|
-
|
|
863
|
+
def _jax_unary_helper(self, expr, aux, jax_op, at_least_int=False, check_dtype=None):
|
|
864
|
+
JaxRDDLCompiler._check_num_args(expr, 1)
|
|
865
|
+
arg, = expr.args
|
|
866
|
+
jax_expr = self._jax(arg, aux)
|
|
867
|
+
return self._jax_unary(
|
|
868
|
+
jax_expr, jax_op, at_least_int=at_least_int, check_dtype=check_dtype)
|
|
869
|
+
|
|
870
|
+
def _jax_binary_helper(self, expr, aux, jax_op, at_least_int=False, check_dtype=None):
|
|
827
871
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
872
|
+
lhs, rhs = expr.args
|
|
873
|
+
jax_lhs = self._jax(lhs, aux)
|
|
874
|
+
jax_rhs = self._jax(rhs, aux)
|
|
875
|
+
return self._jax_binary(
|
|
876
|
+
jax_lhs, jax_rhs, jax_op, at_least_int=at_least_int, check_dtype=check_dtype)
|
|
828
877
|
|
|
829
|
-
def
|
|
878
|
+
def _jax_nary_helper(self, expr, aux, jax_op, at_least_int=False, check_dtype=None):
|
|
879
|
+
JaxRDDLCompiler._check_num_args_min(expr, 2)
|
|
880
|
+
args = expr.args
|
|
881
|
+
jax_exprs = [self._jax(arg, aux) for arg in args]
|
|
882
|
+
result = jax_exprs[0]
|
|
883
|
+
for jax_rhs in jax_exprs[1:]:
|
|
884
|
+
result = self._jax_binary(
|
|
885
|
+
result, jax_rhs, jax_op, at_least_int=at_least_int, check_dtype=check_dtype)
|
|
886
|
+
return result
|
|
887
|
+
|
|
888
|
+
# ===========================================================================
|
|
889
|
+
# arithmetic
|
|
890
|
+
# ===========================================================================
|
|
891
|
+
|
|
892
|
+
def _jax_arithmetic(self, expr, aux):
|
|
893
|
+
JaxRDDLCompiler._check_valid_op(expr, {'-', '+', '*', '/'})
|
|
830
894
|
_, op = expr.etype
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
895
|
+
if op == '-':
|
|
896
|
+
if len(expr.args) == 1:
|
|
897
|
+
return self._jax_negate(expr, aux)
|
|
898
|
+
else:
|
|
899
|
+
return self._jax_subtract(expr, aux)
|
|
900
|
+
elif op == '/':
|
|
901
|
+
return self._jax_divide(expr, aux)
|
|
902
|
+
elif op == '+':
|
|
903
|
+
return self._jax_add(expr, aux)
|
|
904
|
+
elif op == '*':
|
|
905
|
+
return self._jax_multiply(expr, aux)
|
|
906
|
+
|
|
907
|
+
def _jax_negate(self, expr, aux):
|
|
908
|
+
return self._jax_unary_helper(expr, aux, jnp.negative, at_least_int=True)
|
|
909
|
+
|
|
910
|
+
def _jax_add(self, expr, aux):
|
|
911
|
+
return self._jax_nary_helper(expr, aux, jnp.add, at_least_int=True)
|
|
912
|
+
|
|
913
|
+
def _jax_subtract(self, expr, aux):
|
|
914
|
+
return self._jax_binary_helper(expr, aux, jnp.subtract, at_least_int=True)
|
|
915
|
+
|
|
916
|
+
def _jax_multiply(self, expr, aux):
|
|
917
|
+
return self._jax_nary_helper(expr, aux, jnp.multiply, at_least_int=True)
|
|
918
|
+
|
|
919
|
+
def _jax_divide(self, expr, aux):
|
|
920
|
+
return self._jax_binary_helper(expr, aux, jnp.divide, at_least_int=True)
|
|
921
|
+
|
|
922
|
+
# ===========================================================================
|
|
923
|
+
# relational
|
|
924
|
+
# ===========================================================================
|
|
925
|
+
|
|
926
|
+
def _jax_relational(self, expr, aux):
|
|
927
|
+
JaxRDDLCompiler._check_valid_op(expr, {'>=', '<=', '>', '<', '==', '~='})
|
|
848
928
|
_, op = expr.etype
|
|
929
|
+
if op == '>=':
|
|
930
|
+
return self._jax_greater_equal(expr, aux)
|
|
931
|
+
elif op == '<=':
|
|
932
|
+
return self._jax_less_equal(expr, aux)
|
|
933
|
+
elif op == '>':
|
|
934
|
+
return self._jax_greater(expr, aux)
|
|
935
|
+
elif op == '<':
|
|
936
|
+
return self._jax_less(expr, aux)
|
|
937
|
+
elif op == '==':
|
|
938
|
+
return self._jax_equal(expr, aux)
|
|
939
|
+
elif op == '~=':
|
|
940
|
+
return self._jax_not_equal(expr, aux)
|
|
941
|
+
|
|
942
|
+
def _jax_greater_equal(self, expr, aux):
|
|
943
|
+
return self._jax_binary_helper(expr, aux, jnp.greater_equal, at_least_int=True)
|
|
849
944
|
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
valid_ops = self.OPS['logical']
|
|
856
|
-
logical_not_op = self.OPS['logical_not']
|
|
857
|
-
JaxRDDLCompiler._check_valid_op(expr, valid_ops)
|
|
858
|
-
|
|
859
|
-
# recursively compile arguments
|
|
860
|
-
args = expr.args
|
|
861
|
-
n = len(args)
|
|
862
|
-
if n == 1 and op == '~':
|
|
863
|
-
arg, = args
|
|
864
|
-
jax_expr = self._jax(arg, init_params)
|
|
865
|
-
jax_op = logical_not_op(expr.id, init_params)
|
|
866
|
-
return self._jax_unary(jax_expr, jax_op, check_dtype=bool)
|
|
867
|
-
|
|
868
|
-
elif n == 2 or (n >= 2 and op in {'^', '&', '|'}):
|
|
869
|
-
jax_exprs = [self._jax(arg, init_params) for arg in args]
|
|
870
|
-
result = jax_exprs[0]
|
|
871
|
-
for i, jax_rhs in enumerate(jax_exprs[1:]):
|
|
872
|
-
jax_op = valid_ops[op](f'{expr.id}_{op}{i}', init_params)
|
|
873
|
-
result = self._jax_binary(result, jax_rhs, jax_op, check_dtype=bool)
|
|
874
|
-
return result
|
|
875
|
-
|
|
876
|
-
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
945
|
+
def _jax_less_equal(self, expr, aux):
|
|
946
|
+
return self._jax_binary_helper(expr, aux, jnp.less_equal, at_least_int=True)
|
|
947
|
+
|
|
948
|
+
def _jax_greater(self, expr, aux):
|
|
949
|
+
return self._jax_binary_helper(expr, aux, jnp.greater, at_least_int=True)
|
|
877
950
|
|
|
878
|
-
def
|
|
879
|
-
|
|
951
|
+
def _jax_less(self, expr, aux):
|
|
952
|
+
return self._jax_binary_helper(expr, aux, jnp.less, at_least_int=True)
|
|
953
|
+
|
|
954
|
+
def _jax_equal(self, expr, aux):
|
|
955
|
+
return self._jax_binary_helper(expr, aux, jnp.equal, at_least_int=True)
|
|
956
|
+
|
|
957
|
+
def _jax_not_equal(self, expr, aux):
|
|
958
|
+
return self._jax_binary_helper(expr, aux, jnp.not_equal, at_least_int=True)
|
|
959
|
+
|
|
960
|
+
# ===========================================================================
|
|
961
|
+
# logical
|
|
962
|
+
# ===========================================================================
|
|
963
|
+
|
|
964
|
+
def _jax_logical(self, expr, aux):
|
|
965
|
+
JaxRDDLCompiler._check_valid_op(expr, {'^', '&', '|', '~', '=>', '<=>'})
|
|
880
966
|
_, op = expr.etype
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
valid_ops = self.EXACT_OPS['aggregation']
|
|
885
|
-
else:
|
|
886
|
-
valid_ops = self.OPS['aggregation']
|
|
887
|
-
JaxRDDLCompiler._check_valid_op(expr, valid_ops)
|
|
888
|
-
is_floating = op not in self.AGGREGATION_BOOL
|
|
889
|
-
|
|
890
|
-
# recursively compile arguments
|
|
891
|
-
* _, arg = expr.args
|
|
892
|
-
_, axes = self.traced.cached_sim_info(expr)
|
|
893
|
-
jax_expr = self._jax(arg, init_params)
|
|
894
|
-
jax_op = valid_ops[op](expr.id, init_params)
|
|
895
|
-
|
|
896
|
-
def _jax_wrapped_aggregation(x, params, key):
|
|
897
|
-
sample, key, err, params = jax_expr(x, params, key)
|
|
898
|
-
if is_floating:
|
|
899
|
-
sample = self.ONE * sample
|
|
967
|
+
if op == '~':
|
|
968
|
+
if len(expr.args) == 1:
|
|
969
|
+
return self._jax_not(expr, aux)
|
|
900
970
|
else:
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
971
|
+
return self._jax_xor(expr, aux)
|
|
972
|
+
elif op == '^' or op == '&':
|
|
973
|
+
return self._jax_and(expr, aux)
|
|
974
|
+
elif op == '|':
|
|
975
|
+
return self._jax_or(expr, aux)
|
|
976
|
+
elif op == '=>':
|
|
977
|
+
return self._jax_implies(expr, aux)
|
|
978
|
+
elif op == '<=>':
|
|
979
|
+
return self._jax_equiv(expr, aux)
|
|
980
|
+
|
|
981
|
+
def _jax_not(self, expr, aux):
|
|
982
|
+
return self._jax_unary_helper(expr, aux, jnp.logical_not, check_dtype=jnp.bool_)
|
|
983
|
+
|
|
984
|
+
def _jax_and(self, expr, aux):
|
|
985
|
+
return self._jax_nary_helper(expr, aux, jnp.logical_and, check_dtype=jnp.bool_)
|
|
986
|
+
|
|
987
|
+
def _jax_or(self, expr, aux):
|
|
988
|
+
return self._jax_nary_helper(expr, aux, jnp.logical_or, check_dtype=jnp.bool_)
|
|
989
|
+
|
|
990
|
+
def _jax_xor(self, expr, aux):
|
|
991
|
+
return self._jax_binary_helper(expr, aux, jnp.logical_xor, check_dtype=jnp.bool_)
|
|
992
|
+
|
|
993
|
+
def _jax_implies(self, expr, aux):
|
|
994
|
+
def implies_op(x, y):
|
|
995
|
+
return jnp.logical_or(jnp.logical_not(x), y)
|
|
996
|
+
return self._jax_binary_helper(expr, aux, implies_op, check_dtype=jnp.bool_)
|
|
997
|
+
|
|
998
|
+
def _jax_equiv(self, expr, aux):
|
|
999
|
+
return self._jax_binary_helper(expr, aux, jnp.equal, check_dtype=jnp.bool_)
|
|
1000
|
+
|
|
1001
|
+
# ===========================================================================
|
|
1002
|
+
# aggregation
|
|
1003
|
+
# ===========================================================================
|
|
1004
|
+
|
|
1005
|
+
def _jax_aggregation(self, expr, aux):
|
|
1006
|
+
JaxRDDLCompiler._check_valid_op(expr, {'sum', 'avg', 'prod', 'minimum', 'maximum',
|
|
1007
|
+
'forall', 'exists', 'argmin', 'argmax'})
|
|
909
1008
|
_, op = expr.etype
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
return self.
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
|
|
1009
|
+
if op == 'sum':
|
|
1010
|
+
return self._jax_sum(expr, aux)
|
|
1011
|
+
elif op == 'avg':
|
|
1012
|
+
return self._jax_avg(expr, aux)
|
|
1013
|
+
elif op == 'prod':
|
|
1014
|
+
return self._jax_prod(expr, aux)
|
|
1015
|
+
elif op == 'minimum':
|
|
1016
|
+
return self._jax_minimum(expr, aux)
|
|
1017
|
+
elif op == 'maximum':
|
|
1018
|
+
return self._jax_maximum(expr, aux)
|
|
1019
|
+
elif op == 'forall':
|
|
1020
|
+
return self._jax_forall(expr, aux)
|
|
1021
|
+
elif op == 'exists':
|
|
1022
|
+
return self._jax_exists(expr, aux)
|
|
1023
|
+
elif op == 'argmin':
|
|
1024
|
+
return self._jax_argmin(expr, aux)
|
|
1025
|
+
elif op == 'argmax':
|
|
1026
|
+
return self._jax_argmax(expr, aux)
|
|
1027
|
+
|
|
1028
|
+
def _jax_aggregation_helper(self, expr, aux, jax_op, is_bool=False):
|
|
1029
|
+
arg = expr.args[-1]
|
|
1030
|
+
_, axes = self.traced.cached_sim_info(expr)
|
|
1031
|
+
jax_expr = self._jax(arg, aux)
|
|
1032
|
+
return self._jax_unary(
|
|
1033
|
+
jax_expr,
|
|
1034
|
+
jax_op=partial(jax_op, axis=axes),
|
|
1035
|
+
at_least_int=not is_bool,
|
|
1036
|
+
check_dtype=jnp.bool_ if is_bool else None
|
|
1037
|
+
)
|
|
1038
|
+
|
|
1039
|
+
def _jax_sum(self, expr, aux):
|
|
1040
|
+
return self._jax_aggregation_helper(expr, aux, jnp.sum)
|
|
1041
|
+
|
|
1042
|
+
def _jax_avg(self, expr, aux):
|
|
1043
|
+
return self._jax_aggregation_helper(expr, aux, jnp.mean)
|
|
1044
|
+
|
|
1045
|
+
def _jax_prod(self, expr, aux):
|
|
1046
|
+
return self._jax_aggregation_helper(expr, aux, jnp.prod)
|
|
1047
|
+
|
|
1048
|
+
def _jax_minimum(self, expr, aux):
|
|
1049
|
+
return self._jax_aggregation_helper(expr, aux, jnp.min)
|
|
1050
|
+
|
|
1051
|
+
def _jax_maximum(self, expr, aux):
|
|
1052
|
+
return self._jax_aggregation_helper(expr, aux, jnp.max)
|
|
1053
|
+
|
|
1054
|
+
def _jax_forall(self, expr, aux):
|
|
1055
|
+
return self._jax_aggregation_helper(expr, aux, jnp.all, is_bool=True)
|
|
1056
|
+
|
|
1057
|
+
def _jax_exists(self, expr, aux):
|
|
1058
|
+
return self._jax_aggregation_helper(expr, aux, jnp.any, is_bool=True)
|
|
1059
|
+
|
|
1060
|
+
def _jax_argmin(self, expr, aux):
|
|
1061
|
+
return self._jax_aggregation_helper(expr, aux, jnp.argmin)
|
|
1062
|
+
|
|
1063
|
+
def _jax_argmax(self, expr, aux):
|
|
1064
|
+
return self._jax_aggregation_helper(expr, aux, jnp.argmax)
|
|
1065
|
+
|
|
1066
|
+
# ===========================================================================
|
|
1067
|
+
# function
|
|
1068
|
+
# ===========================================================================
|
|
1069
|
+
|
|
1070
|
+
def _jax_function(self, expr, aux):
|
|
1071
|
+
JaxRDDLCompiler._check_valid_op(expr, {'abs', 'sgn', 'round', 'floor', 'ceil',
|
|
1072
|
+
'cos', 'sin', 'tan', 'acos', 'asin', 'atan',
|
|
1073
|
+
'cosh', 'sinh', 'tanh', 'exp', 'ln', 'sqrt',
|
|
1074
|
+
'lngamma', 'gamma',
|
|
1075
|
+
'div', 'mod', 'fmod', 'min', 'max',
|
|
1076
|
+
'pow', 'log', 'hypot'})
|
|
1077
|
+
_, op = expr.etype
|
|
1078
|
+
|
|
1079
|
+
# unary functions
|
|
1080
|
+
if op == 'abs':
|
|
1081
|
+
return self._jax_abs(expr, aux)
|
|
1082
|
+
elif op == 'sgn':
|
|
1083
|
+
return self._jax_sgn(expr, aux)
|
|
1084
|
+
elif op == 'round':
|
|
1085
|
+
return self._jax_round(expr, aux)
|
|
1086
|
+
elif op == 'floor':
|
|
1087
|
+
return self._jax_floor(expr, aux)
|
|
1088
|
+
elif op == 'ceil':
|
|
1089
|
+
return self._jax_ceil(expr, aux)
|
|
1090
|
+
elif op == 'cos':
|
|
1091
|
+
return self._jax_cos(expr, aux)
|
|
1092
|
+
elif op == 'sin':
|
|
1093
|
+
return self._jax_sin(expr, aux)
|
|
1094
|
+
elif op == 'tan':
|
|
1095
|
+
return self._jax_tan(expr, aux)
|
|
1096
|
+
elif op == 'acos':
|
|
1097
|
+
return self._jax_acos(expr, aux)
|
|
1098
|
+
elif op == 'asin':
|
|
1099
|
+
return self._jax_asin(expr, aux)
|
|
1100
|
+
elif op == 'atan':
|
|
1101
|
+
return self._jax_atan(expr, aux)
|
|
1102
|
+
elif op == 'cosh':
|
|
1103
|
+
return self._jax_cosh(expr, aux)
|
|
1104
|
+
elif op == 'sinh':
|
|
1105
|
+
return self._jax_sinh(expr, aux)
|
|
1106
|
+
elif op == 'tanh':
|
|
1107
|
+
return self._jax_tanh(expr, aux)
|
|
1108
|
+
elif op == 'exp':
|
|
1109
|
+
return self._jax_exp(expr, aux)
|
|
1110
|
+
elif op == 'ln':
|
|
1111
|
+
return self._jax_ln(expr, aux)
|
|
1112
|
+
elif op == 'sqrt':
|
|
1113
|
+
return self._jax_sqrt(expr, aux)
|
|
1114
|
+
elif op == 'lngamma':
|
|
1115
|
+
return self._jax_lngamma(expr, aux)
|
|
1116
|
+
elif op == 'gamma':
|
|
1117
|
+
return self._jax_gamma(expr, aux)
|
|
1118
|
+
|
|
1119
|
+
# binary functions
|
|
1120
|
+
elif op == 'div':
|
|
1121
|
+
return self._jax_div(expr, aux)
|
|
1122
|
+
elif op == 'mod':
|
|
1123
|
+
return self._jax_mod(expr, aux)
|
|
1124
|
+
elif op == 'fmod':
|
|
1125
|
+
return self._jax_fmod(expr, aux)
|
|
1126
|
+
elif op == 'min':
|
|
1127
|
+
return self._jax_min(expr, aux)
|
|
1128
|
+
elif op == 'max':
|
|
1129
|
+
return self._jax_max(expr, aux)
|
|
1130
|
+
elif op == 'pow':
|
|
1131
|
+
return self._jax_pow(expr, aux)
|
|
1132
|
+
elif op == 'log':
|
|
1133
|
+
return self._jax_log(expr, aux)
|
|
1134
|
+
elif op == 'hypot':
|
|
1135
|
+
return self._jax_hypot(expr, aux)
|
|
1136
|
+
|
|
1137
|
+
def _jax_abs(self, expr, aux):
|
|
1138
|
+
return self._jax_unary_helper(expr, aux, jnp.abs, at_least_int=True)
|
|
1139
|
+
|
|
1140
|
+
def _jax_sgn(self, expr, aux):
|
|
1141
|
+
return self._jax_unary_helper(expr, aux, jnp.sign, at_least_int=True)
|
|
1142
|
+
|
|
1143
|
+
def _jax_round(self, expr, aux):
|
|
1144
|
+
return self._jax_unary_helper(expr, aux, jnp.round, at_least_int=True)
|
|
1145
|
+
|
|
1146
|
+
def _jax_floor(self, expr, aux):
|
|
1147
|
+
return self._jax_unary_helper(expr, aux, jnp.floor, at_least_int=True)
|
|
1148
|
+
|
|
1149
|
+
def _jax_ceil(self, expr, aux):
|
|
1150
|
+
return self._jax_unary_helper(expr, aux, jnp.ceil, at_least_int=True)
|
|
1151
|
+
|
|
1152
|
+
def _jax_cos(self, expr, aux):
|
|
1153
|
+
return self._jax_unary_helper(expr, aux, jnp.cos, at_least_int=True)
|
|
1154
|
+
|
|
1155
|
+
def _jax_sin(self, expr, aux):
|
|
1156
|
+
return self._jax_unary_helper(expr, aux, jnp.sin, at_least_int=True)
|
|
1157
|
+
|
|
1158
|
+
def _jax_tan(self, expr, aux):
|
|
1159
|
+
return self._jax_unary_helper(expr, aux, jnp.tan, at_least_int=True)
|
|
1160
|
+
|
|
1161
|
+
def _jax_acos(self, expr, aux):
|
|
1162
|
+
return self._jax_unary_helper(expr, aux, jnp.arccos, at_least_int=True)
|
|
1163
|
+
|
|
1164
|
+
def _jax_asin(self, expr, aux):
|
|
1165
|
+
return self._jax_unary_helper(expr, aux, jnp.arcsin, at_least_int=True)
|
|
1166
|
+
|
|
1167
|
+
def _jax_atan(self, expr, aux):
|
|
1168
|
+
return self._jax_unary_helper(expr, aux, jnp.arctan, at_least_int=True)
|
|
1169
|
+
|
|
1170
|
+
def _jax_cosh(self, expr, aux):
|
|
1171
|
+
return self._jax_unary_helper(expr, aux, jnp.cosh, at_least_int=True)
|
|
1172
|
+
|
|
1173
|
+
def _jax_sinh(self, expr, aux):
|
|
1174
|
+
return self._jax_unary_helper(expr, aux, jnp.sinh, at_least_int=True)
|
|
1175
|
+
|
|
1176
|
+
def _jax_tanh(self, expr, aux):
|
|
1177
|
+
return self._jax_unary_helper(expr, aux, jnp.tanh, at_least_int=True)
|
|
1178
|
+
|
|
1179
|
+
def _jax_exp(self, expr, aux):
|
|
1180
|
+
return self._jax_unary_helper(expr, aux, jnp.exp, at_least_int=True)
|
|
1181
|
+
|
|
1182
|
+
def _jax_ln(self, expr, aux):
|
|
1183
|
+
return self._jax_unary_helper(expr, aux, jnp.ln, at_least_int=True)
|
|
1184
|
+
|
|
1185
|
+
def _jax_sqrt(self, expr, aux):
|
|
1186
|
+
return self._jax_unary_helper(expr, aux, jnp.sqrt, at_least_int=True)
|
|
1187
|
+
|
|
1188
|
+
def _jax_lngamma(self, expr, aux):
|
|
1189
|
+
return self._jax_unary_helper(expr, aux, scipy.special.gammaln, at_least_int=True)
|
|
1190
|
+
|
|
1191
|
+
def _jax_gamma(self, expr, aux):
|
|
1192
|
+
return self._jax_unary_helper(expr, aux, scipy.special.gamma, at_least_int=True)
|
|
1193
|
+
|
|
1194
|
+
def _jax_div(self, expr, aux):
|
|
1195
|
+
return self._jax_binary_helper(expr, aux, jnp.floor_divide, at_least_int=True)
|
|
1196
|
+
|
|
1197
|
+
def _jax_mod(self, expr, aux):
|
|
1198
|
+
return self._jax_binary_helper(expr, aux, jnp.mod, at_least_int=True)
|
|
1199
|
+
|
|
1200
|
+
def _jax_fmod(self, expr, aux):
|
|
1201
|
+
return self._jax_binary_helper(expr, aux, jnp.mod, at_least_int=True)
|
|
1202
|
+
|
|
1203
|
+
def _jax_min(self, expr, aux):
|
|
1204
|
+
return self._jax_binary_helper(expr, aux, jnp.minimum, at_least_int=True)
|
|
1205
|
+
|
|
1206
|
+
def _jax_max(self, expr, aux):
|
|
1207
|
+
return self._jax_binary_helper(expr, aux, jnp.maximum, at_least_int=True)
|
|
1208
|
+
|
|
1209
|
+
def _jax_pow(self, expr, aux):
|
|
1210
|
+
return self._jax_binary_helper(expr, aux, jnp.power, at_least_int=True)
|
|
1211
|
+
|
|
1212
|
+
def _jax_log(self, expr, aux):
|
|
1213
|
+
def log_op(x, y):
|
|
1214
|
+
return jnp.log(x) / jnp.log(y)
|
|
1215
|
+
return self._jax_binary_helper(expr, aux, log_op, at_least_int=True)
|
|
1216
|
+
|
|
1217
|
+
def _jax_hypot(self, expr, aux):
|
|
1218
|
+
return self._jax_binary_helper(expr, aux, jnp.hypot, at_least_int=True)
|
|
1219
|
+
|
|
1220
|
+
# ===========================================================================
|
|
1221
|
+
# external function
|
|
1222
|
+
# ===========================================================================
|
|
1223
|
+
|
|
1224
|
+
def _jax_pyfunc(self, expr, aux):
|
|
939
1225
|
NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
|
|
940
1226
|
|
|
941
1227
|
# get the Python function by name
|
|
@@ -957,25 +1243,21 @@ class JaxRDDLCompiler:
|
|
|
957
1243
|
require_dims = self.rddl.object_counts(captured_types)
|
|
958
1244
|
|
|
959
1245
|
# compile the inputs to the function
|
|
960
|
-
jax_inputs = [self._jax(arg,
|
|
1246
|
+
jax_inputs = [self._jax(arg, aux) for arg in args]
|
|
961
1247
|
|
|
962
1248
|
# compile the function evaluation function
|
|
963
|
-
def _jax_wrapped_external_function(
|
|
1249
|
+
def _jax_wrapped_external_function(fls, nfls, params, key):
|
|
964
1250
|
|
|
965
1251
|
# evaluate inputs to the function
|
|
966
1252
|
# first dimensions are non-captured vars in outer scope followed by all the _
|
|
967
1253
|
error = NORMAL
|
|
968
1254
|
flat_samples = []
|
|
969
1255
|
for jax_expr in jax_inputs:
|
|
970
|
-
sample, key, err, params = jax_expr(
|
|
971
|
-
|
|
972
|
-
first_dim = 1
|
|
973
|
-
for dim in shape[:num_free_vars]:
|
|
974
|
-
first_dim *= dim
|
|
975
|
-
new_shape = (first_dim,) + shape[num_free_vars:]
|
|
1256
|
+
sample, key, err, params = jax_expr(fls, nfls, params, key)
|
|
1257
|
+
new_shape = (-1,) + jnp.shape(sample)[num_free_vars:]
|
|
976
1258
|
flat_sample = jnp.reshape(sample, new_shape)
|
|
977
1259
|
flat_samples.append(flat_sample)
|
|
978
|
-
error
|
|
1260
|
+
error = error | err
|
|
979
1261
|
|
|
980
1262
|
# now all the inputs have dimensions equal to (k,) + the number of _ occurences
|
|
981
1263
|
# k is the number of possible non-captured object combinations
|
|
@@ -986,7 +1268,8 @@ class JaxRDDLCompiler:
|
|
|
986
1268
|
if not isinstance(sample, jnp.ndarray):
|
|
987
1269
|
raise ValueError(
|
|
988
1270
|
f'Output of external Python function <{pyfunc_name}> '
|
|
989
|
-
f'is not a JAX array.\n' + print_stack_trace(expr)
|
|
1271
|
+
f'is not a JAX array.\n' + print_stack_trace(expr)
|
|
1272
|
+
)
|
|
990
1273
|
|
|
991
1274
|
pyfunc_dims = jnp.shape(sample)[1:]
|
|
992
1275
|
if len(require_dims) != len(pyfunc_dims):
|
|
@@ -994,14 +1277,16 @@ class JaxRDDLCompiler:
|
|
|
994
1277
|
f'External Python function <{pyfunc_name}> returned array with '
|
|
995
1278
|
f'{len(pyfunc_dims)} dimensions, which does not match the '
|
|
996
1279
|
f'number of captured parameter(s) {len(require_dims)}.\n' +
|
|
997
|
-
print_stack_trace(expr)
|
|
1280
|
+
print_stack_trace(expr)
|
|
1281
|
+
)
|
|
998
1282
|
for (param, require_dim, actual_dim) in zip(captured_vars, require_dims, pyfunc_dims):
|
|
999
1283
|
if require_dim != actual_dim:
|
|
1000
1284
|
raise ValueError(
|
|
1001
1285
|
f'External Python function <{pyfunc_name}> returned array with '
|
|
1002
1286
|
f'{actual_dim} elements for captured parameter <{param}>, '
|
|
1003
1287
|
f'which does not match the number of objects {require_dim}.\n' +
|
|
1004
|
-
print_stack_trace(expr)
|
|
1288
|
+
print_stack_trace(expr)
|
|
1289
|
+
)
|
|
1005
1290
|
|
|
1006
1291
|
# unravel the combinations k back into their original dimensions
|
|
1007
1292
|
sample = jnp.reshape(sample, free_dims + pyfunc_dims)
|
|
@@ -1017,111 +1302,75 @@ class JaxRDDLCompiler:
|
|
|
1017
1302
|
# control flow
|
|
1018
1303
|
# ===========================================================================
|
|
1019
1304
|
|
|
1020
|
-
def _jax_control(self, expr, aux):
    """Compile a control-flow expression.

    ``_check_valid_op`` is called first with the allowed operators
    ``{'if', 'switch'}``. The expression is then routed to ``_jax_if`` or
    ``_jax_switch`` according to the operator in ``expr.etype``.
    """
    JaxRDDLCompiler._check_valid_op(expr, {'if', 'switch'})
    _, op = expr.etype
    if op == 'switch':
        return self._jax_switch(expr, aux)
    if op == 'if':
        return self._jax_if(expr, aux)
    return None
|
|
1029
1312
|
|
|
1030
|
-
def _jax_if(self, expr, aux):
    """Compile ``if (pred) then (a) else (b)`` into a sampling closure.

    The predicate and both branches are each evaluated on every call, in
    that order, threading ``key`` and ``params`` through all three.
    ``jnp.where(pred > 0.5, a, b)`` then selects element-wise. The
    ``INVALID_CAST`` error bit is OR-ed in when the predicate sample cannot
    be safely cast to ``bool``.
    """
    ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
    JaxRDDLCompiler._check_num_args(expr, 3)
    arg_pred, arg_then, arg_else = expr.args

    # compile the predicate and both branches once, at compile time
    jax_pred = self._jax(arg_pred, aux)
    jax_true = self._jax(arg_then, aux)
    jax_false = self._jax(arg_else, aux)

    def _jax_wrapped_if_then_else(fls, nfls, params, key):
        cond, key, err_cond, params = jax_pred(fls, nfls, params, key)
        then_val, key, err_then, params = jax_true(fls, nfls, params, key)
        else_val, key, err_else, params = jax_false(fls, nfls, params, key)
        sample = jnp.where(cond > 0.5, then_val, else_val)

        # flag predicates whose dtype cannot be safely cast to bool
        bad_cast = jnp.logical_not(jnp.can_cast(cond, jnp.bool_))
        err = err_cond | err_then | err_else | (bad_cast * ERR)
        return sample, key, err, params

    return _jax_wrapped_if_then_else
|
|
1058
1333
|
|
|
1059
|
-
def _jax_switch(self, expr,
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
# if predicate is non-fluent, always use the exact operation
|
|
1063
|
-
# case conditions are currently only literals so they are non-fluent
|
|
1064
|
-
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(pred):
|
|
1065
|
-
switch_op = self.EXACT_OPS['control']['switch']
|
|
1066
|
-
else:
|
|
1067
|
-
switch_op = self.OPS['control']['switch']
|
|
1068
|
-
jax_op = switch_op(expr.id, init_params)
|
|
1069
|
-
|
|
1334
|
+
def _jax_switch(self, expr, aux):
|
|
1335
|
+
|
|
1070
1336
|
# recursively compile predicate
|
|
1071
|
-
|
|
1337
|
+
pred = expr.args[0]
|
|
1338
|
+
jax_pred = self._jax(pred, aux)
|
|
1072
1339
|
|
|
1073
1340
|
# recursively compile cases
|
|
1074
1341
|
cases, default = self.traced.cached_sim_info(expr)
|
|
1075
|
-
jax_default = None if default is None else self._jax(default,
|
|
1076
|
-
jax_cases = [
|
|
1077
|
-
|
|
1342
|
+
jax_default = None if default is None else self._jax(default, aux)
|
|
1343
|
+
jax_cases = [
|
|
1344
|
+
(jax_default if _case is None else self._jax(_case, aux))
|
|
1345
|
+
for _case in cases
|
|
1346
|
+
]
|
|
1078
1347
|
|
|
1079
|
-
def _jax_wrapped_switch(
|
|
1348
|
+
def _jax_wrapped_switch(fls, nfls, params, key):
|
|
1080
1349
|
|
|
1081
1350
|
# sample predicate
|
|
1082
|
-
sample_pred, key, err, params = jax_pred(
|
|
1351
|
+
sample_pred, key, err, params = jax_pred(fls, nfls, params, key)
|
|
1083
1352
|
|
|
1084
1353
|
# sample cases
|
|
1085
|
-
sample_cases = [
|
|
1086
|
-
for
|
|
1087
|
-
|
|
1088
|
-
|
|
1354
|
+
sample_cases = []
|
|
1355
|
+
for jax_case in jax_cases:
|
|
1356
|
+
sample, key, err_case, params = jax_case(fls, nfls, params, key)
|
|
1357
|
+
sample_cases.append(sample)
|
|
1358
|
+
err = err | err_case
|
|
1089
1359
|
sample_cases = jnp.asarray(sample_cases)
|
|
1090
1360
|
sample_cases = jnp.asarray(sample_cases, dtype=self._fix_dtype(sample_cases))
|
|
1091
1361
|
|
|
1092
1362
|
# predicate (enum) is an integer - use it to extract from case array
|
|
1093
|
-
|
|
1363
|
+
sample_pred = jnp.asarray(sample_pred[jnp.newaxis, ...], dtype=self.INT)
|
|
1364
|
+
sample = jnp.take_along_axis(sample_cases, sample_pred, axis=0)
|
|
1365
|
+
assert sample.shape[0] == 1
|
|
1366
|
+
sample = sample[0, ...]
|
|
1094
1367
|
return sample, key, err, params
|
|
1095
|
-
|
|
1096
1368
|
return _jax_wrapped_switch
|
|
1097
1369
|
|
|
1098
1370
|
# ===========================================================================
|
|
1099
1371
|
# random variables
|
|
1100
1372
|
# ===========================================================================
|
|
1101
1373
|
|
|
1102
|
-
# distributions with complete reparameterization support:
|
|
1103
|
-
# KronDelta: complete
|
|
1104
|
-
# DiracDelta: complete
|
|
1105
|
-
# Uniform: complete
|
|
1106
|
-
# Bernoulli: complete (subclass uses Gumbel-softmax)
|
|
1107
|
-
# Normal: complete
|
|
1108
|
-
# Exponential: complete
|
|
1109
|
-
# Geometric: complete
|
|
1110
|
-
# Weibull: complete
|
|
1111
|
-
# Pareto: complete
|
|
1112
|
-
# Gumbel: complete
|
|
1113
|
-
# Laplace: complete
|
|
1114
|
-
# Cauchy: complete
|
|
1115
|
-
# Gompertz: complete
|
|
1116
|
-
# Kumaraswamy: complete
|
|
1117
|
-
# Discrete: complete (subclass uses Gumbel-softmax)
|
|
1118
|
-
# UnnormDiscrete: complete (subclass uses Gumbel-softmax)
|
|
1119
|
-
# Discrete(p): complete (subclass uses Gumbel-softmax)
|
|
1120
|
-
# UnnormDiscrete(p): complete (subclass uses Gumbel-softmax)
|
|
1121
|
-
# Poisson (subclass uses Gumbel-softmax or Poisson process trick)
|
|
1122
|
-
# Binomial (subclass uses Gumbel-softmax or Normal approximation)
|
|
1123
|
-
# NegativeBinomial (subclass uses Poisson-Gamma mixture)
|
|
1124
|
-
|
|
1125
1374
|
# distributions which seem to support backpropagation (need more testing):
|
|
1126
1375
|
# Beta
|
|
1127
1376
|
# Student
|
|
@@ -1132,656 +1381,587 @@ class JaxRDDLCompiler:
|
|
|
1132
1381
|
# distributions with incomplete reparameterization support (TODO):
|
|
1133
1382
|
# Multinomial
|
|
1134
1383
|
|
|
1135
|
-
def _jax_random(self, expr,
|
|
1384
|
+
def _jax_random(self, expr, aux):
|
|
1136
1385
|
_, name = expr.etype
|
|
1137
1386
|
if name == 'KronDelta':
|
|
1138
|
-
return self._jax_kron(expr,
|
|
1387
|
+
return self._jax_kron(expr, aux)
|
|
1139
1388
|
elif name == 'DiracDelta':
|
|
1140
|
-
return self._jax_dirac(expr,
|
|
1389
|
+
return self._jax_dirac(expr, aux)
|
|
1141
1390
|
elif name == 'Uniform':
|
|
1142
|
-
return self._jax_uniform(expr,
|
|
1391
|
+
return self._jax_uniform(expr, aux)
|
|
1143
1392
|
elif name == 'Bernoulli':
|
|
1144
|
-
return self._jax_bernoulli(expr,
|
|
1393
|
+
return self._jax_bernoulli(expr, aux)
|
|
1145
1394
|
elif name == 'Normal':
|
|
1146
|
-
return self._jax_normal(expr,
|
|
1395
|
+
return self._jax_normal(expr, aux)
|
|
1147
1396
|
elif name == 'Poisson':
|
|
1148
|
-
return self._jax_poisson(expr,
|
|
1397
|
+
return self._jax_poisson(expr, aux)
|
|
1149
1398
|
elif name == 'Exponential':
|
|
1150
|
-
return self._jax_exponential(expr,
|
|
1399
|
+
return self._jax_exponential(expr, aux)
|
|
1151
1400
|
elif name == 'Weibull':
|
|
1152
|
-
return self._jax_weibull(expr,
|
|
1401
|
+
return self._jax_weibull(expr, aux)
|
|
1153
1402
|
elif name == 'Gamma':
|
|
1154
|
-
return self._jax_gamma(expr,
|
|
1403
|
+
return self._jax_gamma(expr, aux)
|
|
1155
1404
|
elif name == 'Binomial':
|
|
1156
|
-
return self._jax_binomial(expr,
|
|
1405
|
+
return self._jax_binomial(expr, aux)
|
|
1157
1406
|
elif name == 'NegativeBinomial':
|
|
1158
|
-
return self._jax_negative_binomial(expr,
|
|
1407
|
+
return self._jax_negative_binomial(expr, aux)
|
|
1159
1408
|
elif name == 'Beta':
|
|
1160
|
-
return self._jax_beta(expr,
|
|
1409
|
+
return self._jax_beta(expr, aux)
|
|
1161
1410
|
elif name == 'Geometric':
|
|
1162
|
-
return self._jax_geometric(expr,
|
|
1411
|
+
return self._jax_geometric(expr, aux)
|
|
1163
1412
|
elif name == 'Pareto':
|
|
1164
|
-
return self._jax_pareto(expr,
|
|
1413
|
+
return self._jax_pareto(expr, aux)
|
|
1165
1414
|
elif name == 'Student':
|
|
1166
|
-
return self._jax_student(expr,
|
|
1415
|
+
return self._jax_student(expr, aux)
|
|
1167
1416
|
elif name == 'Gumbel':
|
|
1168
|
-
return self._jax_gumbel(expr,
|
|
1417
|
+
return self._jax_gumbel(expr, aux)
|
|
1169
1418
|
elif name == 'Laplace':
|
|
1170
|
-
return self._jax_laplace(expr,
|
|
1419
|
+
return self._jax_laplace(expr, aux)
|
|
1171
1420
|
elif name == 'Cauchy':
|
|
1172
|
-
return self._jax_cauchy(expr,
|
|
1421
|
+
return self._jax_cauchy(expr, aux)
|
|
1173
1422
|
elif name == 'Gompertz':
|
|
1174
|
-
return self._jax_gompertz(expr,
|
|
1423
|
+
return self._jax_gompertz(expr, aux)
|
|
1175
1424
|
elif name == 'ChiSquare':
|
|
1176
|
-
return self._jax_chisquare(expr,
|
|
1425
|
+
return self._jax_chisquare(expr, aux)
|
|
1177
1426
|
elif name == 'Kumaraswamy':
|
|
1178
|
-
return self._jax_kumaraswamy(expr,
|
|
1427
|
+
return self._jax_kumaraswamy(expr, aux)
|
|
1179
1428
|
elif name == 'Discrete':
|
|
1180
|
-
return self._jax_discrete(expr,
|
|
1429
|
+
return self._jax_discrete(expr, aux, unnorm=False)
|
|
1181
1430
|
elif name == 'UnnormDiscrete':
|
|
1182
|
-
return self._jax_discrete(expr,
|
|
1431
|
+
return self._jax_discrete(expr, aux, unnorm=True)
|
|
1183
1432
|
elif name == 'Discrete(p)':
|
|
1184
|
-
return self._jax_discrete_pvar(expr,
|
|
1433
|
+
return self._jax_discrete_pvar(expr, aux, unnorm=False)
|
|
1185
1434
|
elif name == 'UnnormDiscrete(p)':
|
|
1186
|
-
return self._jax_discrete_pvar(expr,
|
|
1435
|
+
return self._jax_discrete_pvar(expr, aux, unnorm=True)
|
|
1187
1436
|
else:
|
|
1188
1437
|
raise RDDLNotImplementedError(
|
|
1189
1438
|
f'Distribution {name} is not supported.\n' + print_stack_trace(expr))
|
|
1190
1439
|
|
|
1191
|
-
def _jax_kron(self, expr, aux):
    """Compile ``KronDelta(x)``: the argument's sample is passed through.

    The only run-time work is OR-ing in the ``INVALID_PARAM_KRON_DELTA``
    error bit when the sample cannot be safely cast to ``self.INT``.
    """
    ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_KRON_DELTA']
    JaxRDDLCompiler._check_num_args(expr, 1)
    (arg_expr,) = expr.args
    jax_arg = self._jax(arg_expr, aux)

    def _jax_wrapped_distribution_kron(fls, nfls, params, key):
        sample, key, err, params = jax_arg(fls, nfls, params, key)
        castable = jnp.can_cast(sample, self.INT)
        err = err | (jnp.logical_not(castable) * ERR)
        return sample, key, err, params

    return _jax_wrapped_distribution_kron
|
|
1205
1453
|
|
|
1206
|
-
def _jax_dirac(self, expr, aux):
    """Compile ``DiracDelta(x)``.

    Returns the argument's compiled function as-is, compiled with
    ``dtype=self.REAL``; no extra wrapper closure is created.
    """
    JaxRDDLCompiler._check_num_args(expr, 1)
    (arg_expr,) = expr.args
    return self._jax(arg_expr, aux, dtype=self.REAL)
|
|
1211
1459
|
|
|
1212
|
-
def _jax_uniform(self, expr, aux):
    """Compile ``Uniform(lb, ub)`` with the reparameterization
    ``lb + (ub - lb) * U(0, 1)``.

    The unit-uniform noise is drawn at the broadcast shape of ``lb`` and
    ``ub``. Drawing it at ``lb``'s shape alone would give a single shared
    variate whenever ``lb`` is a scalar but ``ub`` is batched, making every
    element perfectly correlated. When both shapes agree the draw is
    unchanged.

    The ``INVALID_PARAM_UNIFORM`` bit is set unless ``lb <= ub`` everywhere.
    """
    ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_UNIFORM']
    JaxRDDLCompiler._check_num_args(expr, 2)

    arg_lb, arg_ub = expr.args
    jax_lb = self._jax(arg_lb, aux)
    jax_ub = self._jax(arg_ub, aux)

    def _jax_wrapped_distribution_uniform(fls, nfls, params, key):
        lb, key, err1, params = jax_lb(fls, nfls, params, key)
        ub, key, err2, params = jax_ub(fls, nfls, params, key)
        key, subkey = random.split(key)
        batch_shape = jnp.broadcast_shapes(jnp.shape(lb), jnp.shape(ub))
        U = random.uniform(key=subkey, shape=batch_shape, dtype=self.REAL)
        sample = lb + (ub - lb) * U
        out_of_bounds = jnp.logical_not(jnp.all(lb <= ub))
        err = err1 | err2 | (out_of_bounds * ERR)
        return sample, key, err, params

    return _jax_wrapped_distribution_uniform
|
|
1232
1479
|
|
|
1233
|
-
def _jax_normal(self, expr, aux):
    """Compile ``Normal(mean, variance)`` with the reparameterization
    ``mean + sqrt(variance) * N(0, 1)``.

    The standard-normal noise is drawn at the broadcast shape of ``mean``
    and ``variance``. Drawing it at ``mean``'s shape alone would give one
    shared variate whenever ``mean`` is a scalar but ``variance`` is
    batched, making every element perfectly correlated. When both shapes
    agree the draw is unchanged.

    The ``INVALID_PARAM_NORMAL`` bit is set unless ``variance >= 0``
    everywhere.
    """
    ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_NORMAL']
    JaxRDDLCompiler._check_num_args(expr, 2)

    arg_mean, arg_var = expr.args
    jax_mean = self._jax(arg_mean, aux)
    jax_var = self._jax(arg_var, aux)

    def _jax_wrapped_distribution_normal(fls, nfls, params, key):
        mean, key, err1, params = jax_mean(fls, nfls, params, key)
        var, key, err2, params = jax_var(fls, nfls, params, key)
        std = jnp.sqrt(var)
        key, subkey = random.split(key)
        batch_shape = jnp.broadcast_shapes(jnp.shape(mean), jnp.shape(var))
        Z = random.normal(key=subkey, shape=batch_shape, dtype=self.REAL)
        sample = mean + std * Z
        out_of_bounds = jnp.logical_not(jnp.all(var >= 0))
        err = err1 | err2 | (out_of_bounds * ERR)
        return sample, key, err, params

    return _jax_wrapped_distribution_normal
|
|
1254
1500
|
|
|
1255
|
-
def _jax_exponential(self, expr, aux):
    """Compile ``Exponential(scale)`` as ``scale * Exp(1)``.

    Unit-exponential noise is drawn at ``scale``'s shape. The
    ``INVALID_PARAM_EXPONENTIAL`` bit is set unless ``scale > 0`` everywhere.
    """
    ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_EXPONENTIAL']
    JaxRDDLCompiler._check_num_args(expr, 1)

    (scale_expr,) = expr.args
    jax_scale = self._jax(scale_expr, aux)

    def _jax_wrapped_distribution_exp(fls, nfls, params, key):
        scale, key, err, params = jax_scale(fls, nfls, params, key)
        key, subkey = random.split(key)
        unit_exp = random.exponential(
            key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
        sample = scale * unit_exp
        all_positive = jnp.all(scale > 0)
        err = err | (jnp.logical_not(all_positive) * ERR)
        return sample, key, err, params

    return _jax_wrapped_distribution_exp
|
|
1273
1518
|
|
|
1274
|
-
def _jax_weibull(self, expr, aux):
    """Compile ``Weibull(shape, scale)`` as a reparameterized sample.

    This is the inverse-CDF form ``scale * (-ln(1 - U)) ** (1 / shape)``,
    evaluated by ``jax.random.weibull_min``. That function's ``shape``
    (batch) argument defaults to ``()``, which draws ONE uniform variate and
    broadcasts it against the parameters, so every element would share the
    same noise. The batch shape is therefore passed explicitly as the
    broadcast of both parameter shapes.

    The ``INVALID_PARAM_WEIBULL`` bit is set unless ``shape > 0`` and
    ``scale > 0`` everywhere.
    """
    ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_WEIBULL']
    JaxRDDLCompiler._check_num_args(expr, 2)

    arg_shape, arg_scale = expr.args
    jax_shape = self._jax(arg_shape, aux)
    jax_scale = self._jax(arg_scale, aux)

    def _jax_wrapped_distribution_weibull(fls, nfls, params, key):
        conc, key, err1, params = jax_shape(fls, nfls, params, key)
        scale, key, err2, params = jax_scale(fls, nfls, params, key)
        key, subkey = random.split(key)
        batch_shape = jnp.broadcast_shapes(jnp.shape(conc), jnp.shape(scale))
        sample = random.weibull_min(
            key=subkey, scale=scale, concentration=conc,
            shape=batch_shape, dtype=self.REAL)
        out_of_bounds = jnp.logical_not(
            jnp.all(jnp.logical_and(conc > 0, scale > 0)))
        err = err1 | err2 | (out_of_bounds * ERR)
        return sample, key, err, params

    return _jax_wrapped_distribution_weibull
|
|
1294
1538
|
|
|
1295
|
-
def _jax_bernoulli(self, expr, aux):
    """Compile ``Bernoulli(p)`` using ``jax.random.bernoulli``.

    The ``INVALID_PARAM_BERNOULLI`` bit is set unless ``0 <= p <= 1``
    everywhere.
    """
    ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_BERNOULLI']
    JaxRDDLCompiler._check_num_args(expr, 1)
    (prob_expr,) = expr.args

    jax_prob = self._jax(prob_expr, aux)

    def _jax_wrapped_distribution_bernoulli(fls, nfls, params, key):
        prob, key, err, params = jax_prob(fls, nfls, params, key)
        key, subkey = random.split(key)
        sample = random.bernoulli(subkey, prob)
        in_unit_interval = jnp.logical_and(prob >= 0, prob <= 1)
        bad_prob = jnp.logical_not(jnp.all(in_unit_interval))
        return sample, key, err | (bad_prob * ERR), params

    return _jax_wrapped_distribution_bernoulli
|
|
1319
1555
|
|
|
1320
|
-
def _jax_poisson(self, expr, aux):
    """Compile ``Poisson(rate)`` using ``jax.random.poisson``.

    Samples are produced with ``dtype=self.INT``. The
    ``INVALID_PARAM_POISSON`` bit is set unless ``rate >= 0`` everywhere.
    """
    ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_POISSON']
    JaxRDDLCompiler._check_num_args(expr, 1)
    (rate_expr,) = expr.args

    jax_rate = self._jax(rate_expr, aux)

    def _jax_wrapped_distribution_poisson(fls, nfls, params, key):
        rate, key, err, params = jax_rate(fls, nfls, params, key)
        key, subkey = random.split(key)
        sample = random.poisson(key=subkey, lam=rate, dtype=self.INT)
        negative_rate = jnp.logical_not(jnp.all(rate >= 0))
        return sample, key, err | (negative_rate * ERR), params

    return _jax_wrapped_distribution_poisson
|
|
1345
1573
|
|
|
1346
|
-
def _jax_gamma(self, expr,
|
|
1574
|
+
def _jax_gamma(self, expr, aux):
|
|
1347
1575
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_GAMMA']
|
|
1348
1576
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1349
1577
|
|
|
1350
1578
|
arg_shape, arg_scale = expr.args
|
|
1351
|
-
jax_shape = self._jax(arg_shape,
|
|
1352
|
-
jax_scale = self._jax(arg_scale,
|
|
1579
|
+
jax_shape = self._jax(arg_shape, aux)
|
|
1580
|
+
jax_scale = self._jax(arg_scale, aux)
|
|
1353
1581
|
|
|
1354
1582
|
# partial reparameterization trick Gamma(s, r) = r * Gamma(s, 1)
|
|
1355
1583
|
# uses the implicit JAX subroutine for Gamma(s, 1)
|
|
1356
|
-
def _jax_wrapped_distribution_gamma(
|
|
1357
|
-
shape, key, err1, params = jax_shape(
|
|
1358
|
-
scale, key, err2, params = jax_scale(
|
|
1584
|
+
def _jax_wrapped_distribution_gamma(fls, nfls, params, key):
|
|
1585
|
+
shape, key, err1, params = jax_shape(fls, nfls, params, key)
|
|
1586
|
+
scale, key, err2, params = jax_scale(fls, nfls, params, key)
|
|
1359
1587
|
key, subkey = random.split(key)
|
|
1360
1588
|
Gamma = random.gamma(key=subkey, a=shape, dtype=self.REAL)
|
|
1361
1589
|
sample = scale * Gamma
|
|
1362
|
-
out_of_bounds = jnp.logical_not(jnp.all((shape > 0
|
|
1590
|
+
out_of_bounds = jnp.logical_not(jnp.all(jnp.logical_and(shape > 0, scale > 0)))
|
|
1363
1591
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1364
|
-
return sample, key, err, params
|
|
1365
|
-
|
|
1592
|
+
return sample, key, err, params
|
|
1366
1593
|
return _jax_wrapped_distribution_gamma
|
|
1367
1594
|
|
|
1368
|
-
def _jax_binomial(self, expr, aux):
    """Compile ``Binomial(trials, prob)`` using ``jax.random.binomial``.

    Both parameters are cast to ``self.REAL`` before sampling, and the draw
    is cast to ``self.INT``. The ``INVALID_PARAM_BINOMIAL`` bit is set unless
    ``0 <= prob <= 1`` and ``trials >= 0`` everywhere.
    """
    ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_BINOMIAL']
    JaxRDDLCompiler._check_num_args(expr, 2)
    trials_expr, prob_expr = expr.args

    jax_trials = self._jax(trials_expr, aux)
    jax_prob = self._jax(prob_expr, aux)

    def _jax_wrapped_distribution_binomial(fls, nfls, params, key):
        n, key, err_n, params = jax_trials(fls, nfls, params, key)
        p, key, err_p, params = jax_prob(fls, nfls, params, key)
        key, subkey = random.split(key)
        n = jnp.asarray(n, dtype=self.REAL)
        p = jnp.asarray(p, dtype=self.REAL)
        draws = random.binomial(key=subkey, n=n, p=p, dtype=self.REAL)
        sample = jnp.asarray(draws, dtype=self.INT)
        p_ok = jnp.logical_and(p >= 0, p <= 1)
        valid = jnp.all(jnp.logical_and(p_ok, n >= 0))
        err = err_p | err_n | (jnp.logical_not(valid) * ERR)
        return sample, key, err, params

    return _jax_wrapped_distribution_binomial
|
|
1397
1617
|
|
|
1398
|
-
def _jax_negative_binomial(self, expr, aux):
    """Compile ``NegativeBinomial(trials, prob)`` via tensorflow-probability.

    The TFP distribution is built with ``total_count=trials`` and
    ``probs=1 - prob``; its sample is cast to ``self.INT``. The
    ``INVALID_PARAM_NEGATIVE_BINOMIAL`` bit is set unless
    ``0 <= prob <= 1`` and ``trials > 0`` everywhere.
    """
    ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_NEGATIVE_BINOMIAL']
    JaxRDDLCompiler._check_num_args(expr, 2)
    trials_expr, prob_expr = expr.args

    jax_trials = self._jax(trials_expr, aux)
    jax_prob = self._jax(prob_expr, aux)

    def _jax_wrapped_distribution_negative_binomial(fls, nfls, params, key):
        n, key, err_n, params = jax_trials(fls, nfls, params, key)
        p, key, err_p, params = jax_prob(fls, nfls, params, key)
        key, subkey = random.split(key)
        n = jnp.asarray(n, dtype=self.REAL)
        p = jnp.asarray(p, dtype=self.REAL)
        negbin = tfp.distributions.NegativeBinomial(total_count=n, probs=1. - p)
        sample = jnp.asarray(negbin.sample(seed=subkey), dtype=self.INT)
        p_ok = jnp.logical_and(p >= 0, p <= 1)
        valid = jnp.all(jnp.logical_and(p_ok, n > 0))
        err = err_p | err_n | (jnp.logical_not(valid) * ERR)
        return sample, key, err, params

    return _jax_wrapped_distribution_negative_binomial
|
|
1427
1640
|
|
|
1428
|
-
def _jax_beta(self, expr, aux):
    """Compile ``Beta(shape, rate)`` using ``jax.random.beta``.

    The first parameter is passed as ``a`` and the second as ``b``. The
    ``INVALID_PARAM_BETA`` bit is set unless both are strictly positive
    everywhere.
    """
    ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_BETA']
    JaxRDDLCompiler._check_num_args(expr, 2)

    alpha_expr, beta_expr = expr.args
    jax_shape = self._jax(alpha_expr, aux)
    jax_rate = self._jax(beta_expr, aux)

    def _jax_wrapped_distribution_beta(fls, nfls, params, key):
        alpha, key, err_a, params = jax_shape(fls, nfls, params, key)
        beta, key, err_b, params = jax_rate(fls, nfls, params, key)
        key, subkey = random.split(key)
        sample = random.beta(key=subkey, a=alpha, b=beta, dtype=self.REAL)
        both_positive = jnp.all(jnp.logical_and(alpha > 0, beta > 0))
        err = err_a | err_b | (jnp.logical_not(both_positive) * ERR)
        return sample, key, err, params

    return _jax_wrapped_distribution_beta
|
|
1447
1659
|
|
|
1448
|
-
def _jax_geometric(self, expr, aux):
    """Compile ``Geometric(p)`` using ``jax.random.geometric``.

    Samples are produced with ``dtype=self.INT``. The
    ``INVALID_PARAM_GEOMETRIC`` bit is set unless ``0 <= p <= 1``
    everywhere.
    """
    ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_GEOMETRIC']
    JaxRDDLCompiler._check_num_args(expr, 1)
    (prob_expr,) = expr.args

    jax_prob = self._jax(prob_expr, aux)

    def _jax_wrapped_distribution_geometric(fls, nfls, params, key):
        prob, key, err, params = jax_prob(fls, nfls, params, key)
        key, subkey = random.split(key)
        sample = random.geometric(key=subkey, p=prob, dtype=self.INT)
        in_unit_interval = jnp.logical_and(prob >= 0, prob <= 1)
        bad_prob = jnp.logical_not(jnp.all(in_unit_interval))
        return sample, key, err | (bad_prob * ERR), params

    return _jax_wrapped_distribution_geometric
|
|
1472
1676
|
|
|
1473
|
-
def _jax_pareto(self, expr,
|
|
1677
|
+
def _jax_pareto(self, expr, aux):
|
|
1474
1678
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_PARETO']
|
|
1475
1679
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1476
1680
|
|
|
1477
1681
|
arg_shape, arg_scale = expr.args
|
|
1478
|
-
jax_shape = self._jax(arg_shape,
|
|
1479
|
-
jax_scale = self._jax(arg_scale,
|
|
1682
|
+
jax_shape = self._jax(arg_shape, aux)
|
|
1683
|
+
jax_scale = self._jax(arg_scale, aux)
|
|
1480
1684
|
|
|
1481
1685
|
# partial reparameterization trick Pareto(s, r) = r * Pareto(s, 1)
|
|
1482
1686
|
# uses the implicit JAX subroutine for Pareto(s, 1)
|
|
1483
|
-
def _jax_wrapped_distribution_pareto(
|
|
1484
|
-
shape, key, err1, params = jax_shape(
|
|
1485
|
-
scale, key, err2, params = jax_scale(
|
|
1687
|
+
def _jax_wrapped_distribution_pareto(fls, nfls, params, key):
|
|
1688
|
+
shape, key, err1, params = jax_shape(fls, nfls, params, key)
|
|
1689
|
+
scale, key, err2, params = jax_scale(fls, nfls, params, key)
|
|
1486
1690
|
key, subkey = random.split(key)
|
|
1487
1691
|
sample = scale * random.pareto(key=subkey, b=shape, dtype=self.REAL)
|
|
1488
|
-
out_of_bounds = jnp.logical_not(jnp.all((shape > 0
|
|
1692
|
+
out_of_bounds = jnp.logical_not(jnp.all(jnp.logical_and(shape > 0, scale > 0)))
|
|
1489
1693
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1490
|
-
return sample, key, err, params
|
|
1491
|
-
|
|
1694
|
+
return sample, key, err, params
|
|
1492
1695
|
return _jax_wrapped_distribution_pareto
|
|
1493
1696
|
|
|
1494
|
-
def _jax_student(self, expr,
|
|
1697
|
+
def _jax_student(self, expr, aux):
|
|
1495
1698
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_STUDENT']
|
|
1496
1699
|
JaxRDDLCompiler._check_num_args(expr, 1)
|
|
1497
1700
|
|
|
1498
1701
|
arg_df, = expr.args
|
|
1499
|
-
jax_df = self._jax(arg_df,
|
|
1702
|
+
jax_df = self._jax(arg_df, aux)
|
|
1500
1703
|
|
|
1501
1704
|
# uses the implicit JAX subroutine for student(df)
|
|
1502
|
-
def _jax_wrapped_distribution_t(
|
|
1503
|
-
df, key, err, params = jax_df(
|
|
1705
|
+
def _jax_wrapped_distribution_t(fls, nfls, params, key):
|
|
1706
|
+
df, key, err, params = jax_df(fls, nfls, params, key)
|
|
1504
1707
|
key, subkey = random.split(key)
|
|
1505
1708
|
sample = random.t(key=subkey, df=df, shape=jnp.shape(df), dtype=self.REAL)
|
|
1506
1709
|
out_of_bounds = jnp.logical_not(jnp.all(df > 0))
|
|
1507
|
-
err
|
|
1508
|
-
return sample, key, err, params
|
|
1509
|
-
|
|
1710
|
+
err = err | (out_of_bounds * ERR)
|
|
1711
|
+
return sample, key, err, params
|
|
1510
1712
|
return _jax_wrapped_distribution_t
|
|
1511
1713
|
|
|
1512
|
-
def _jax_gumbel(self, expr,
|
|
1714
|
+
def _jax_gumbel(self, expr, aux):
|
|
1513
1715
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_GUMBEL']
|
|
1514
1716
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1515
1717
|
|
|
1516
1718
|
arg_mean, arg_scale = expr.args
|
|
1517
|
-
jax_mean = self._jax(arg_mean,
|
|
1518
|
-
jax_scale = self._jax(arg_scale,
|
|
1719
|
+
jax_mean = self._jax(arg_mean, aux)
|
|
1720
|
+
jax_scale = self._jax(arg_scale, aux)
|
|
1519
1721
|
|
|
1520
1722
|
# reparameterization trick Gumbel(m, s) = m + s * Gumbel(0, 1)
|
|
1521
|
-
def _jax_wrapped_distribution_gumbel(
|
|
1522
|
-
mean, key, err1, params = jax_mean(
|
|
1523
|
-
scale, key, err2, params = jax_scale(
|
|
1723
|
+
def _jax_wrapped_distribution_gumbel(fls, nfls, params, key):
|
|
1724
|
+
mean, key, err1, params = jax_mean(fls, nfls, params, key)
|
|
1725
|
+
scale, key, err2, params = jax_scale(fls, nfls, params, key)
|
|
1524
1726
|
key, subkey = random.split(key)
|
|
1525
|
-
|
|
1526
|
-
sample = mean + scale *
|
|
1727
|
+
g = random.gumbel(key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
|
|
1728
|
+
sample = mean + scale * g
|
|
1527
1729
|
out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
|
|
1528
1730
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1529
|
-
return sample, key, err, params
|
|
1530
|
-
|
|
1731
|
+
return sample, key, err, params
|
|
1531
1732
|
return _jax_wrapped_distribution_gumbel
|
|
1532
1733
|
|
|
1533
|
-
def _jax_laplace(self, expr,
|
|
1734
|
+
def _jax_laplace(self, expr, aux):
|
|
1534
1735
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_LAPLACE']
|
|
1535
1736
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1536
1737
|
|
|
1537
1738
|
arg_mean, arg_scale = expr.args
|
|
1538
|
-
jax_mean = self._jax(arg_mean,
|
|
1539
|
-
jax_scale = self._jax(arg_scale,
|
|
1739
|
+
jax_mean = self._jax(arg_mean, aux)
|
|
1740
|
+
jax_scale = self._jax(arg_scale, aux)
|
|
1540
1741
|
|
|
1541
1742
|
# reparameterization trick Laplace(m, s) = m + s * Laplace(0, 1)
|
|
1542
|
-
def _jax_wrapped_distribution_laplace(
|
|
1543
|
-
mean, key, err1, params = jax_mean(
|
|
1544
|
-
scale, key, err2, params = jax_scale(
|
|
1743
|
+
def _jax_wrapped_distribution_laplace(fls, nfls, params, key):
|
|
1744
|
+
mean, key, err1, params = jax_mean(fls, nfls, params, key)
|
|
1745
|
+
scale, key, err2, params = jax_scale(fls, nfls, params, key)
|
|
1545
1746
|
key, subkey = random.split(key)
|
|
1546
|
-
|
|
1547
|
-
sample = mean + scale *
|
|
1747
|
+
lp = random.laplace(key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
|
|
1748
|
+
sample = mean + scale * lp
|
|
1548
1749
|
out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
|
|
1549
1750
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1550
|
-
return sample, key, err, params
|
|
1551
|
-
|
|
1751
|
+
return sample, key, err, params
|
|
1552
1752
|
return _jax_wrapped_distribution_laplace
|
|
1553
1753
|
|
|
1554
|
-
def _jax_cauchy(self, expr,
|
|
1754
|
+
def _jax_cauchy(self, expr, aux):
|
|
1555
1755
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_CAUCHY']
|
|
1556
1756
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1557
1757
|
|
|
1558
1758
|
arg_mean, arg_scale = expr.args
|
|
1559
|
-
jax_mean = self._jax(arg_mean,
|
|
1560
|
-
jax_scale = self._jax(arg_scale,
|
|
1759
|
+
jax_mean = self._jax(arg_mean, aux)
|
|
1760
|
+
jax_scale = self._jax(arg_scale, aux)
|
|
1561
1761
|
|
|
1562
1762
|
# reparameterization trick Cauchy(m, s) = m + s * Cauchy(0, 1)
|
|
1563
|
-
def _jax_wrapped_distribution_cauchy(
|
|
1564
|
-
mean, key, err1, params = jax_mean(
|
|
1565
|
-
scale, key, err2, params = jax_scale(
|
|
1763
|
+
def _jax_wrapped_distribution_cauchy(fls, nfls, params, key):
|
|
1764
|
+
mean, key, err1, params = jax_mean(fls, nfls, params, key)
|
|
1765
|
+
scale, key, err2, params = jax_scale(fls, nfls, params, key)
|
|
1566
1766
|
key, subkey = random.split(key)
|
|
1567
|
-
|
|
1568
|
-
sample = mean + scale *
|
|
1767
|
+
cauchy = random.cauchy(key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
|
|
1768
|
+
sample = mean + scale * cauchy
|
|
1569
1769
|
out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
|
|
1570
1770
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1571
|
-
return sample, key, err, params
|
|
1572
|
-
|
|
1771
|
+
return sample, key, err, params
|
|
1573
1772
|
return _jax_wrapped_distribution_cauchy
|
|
1574
1773
|
|
|
1575
|
-
def _jax_gompertz(self, expr,
|
|
1774
|
+
def _jax_gompertz(self, expr, aux):
|
|
1576
1775
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_GOMPERTZ']
|
|
1577
1776
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1578
1777
|
|
|
1579
1778
|
arg_shape, arg_scale = expr.args
|
|
1580
|
-
jax_shape = self._jax(arg_shape,
|
|
1581
|
-
jax_scale = self._jax(arg_scale,
|
|
1779
|
+
jax_shape = self._jax(arg_shape, aux)
|
|
1780
|
+
jax_scale = self._jax(arg_scale, aux)
|
|
1582
1781
|
|
|
1583
1782
|
# reparameterization trick Gompertz(s, r) = ln(1 - log(U(0, 1)) / s) / r
|
|
1584
|
-
def _jax_wrapped_distribution_gompertz(
|
|
1585
|
-
shape, key, err1, params = jax_shape(
|
|
1586
|
-
scale, key, err2, params = jax_scale(
|
|
1783
|
+
def _jax_wrapped_distribution_gompertz(fls, nfls, params, key):
|
|
1784
|
+
shape, key, err1, params = jax_shape(fls, nfls, params, key)
|
|
1785
|
+
scale, key, err2, params = jax_scale(fls, nfls, params, key)
|
|
1587
1786
|
key, subkey = random.split(key)
|
|
1588
1787
|
U = random.uniform(key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
|
|
1589
1788
|
sample = jnp.log(1.0 - jnp.log1p(-U) / shape) / scale
|
|
1590
|
-
out_of_bounds = jnp.logical_not(jnp.all((shape > 0
|
|
1789
|
+
out_of_bounds = jnp.logical_not(jnp.all(jnp.logical_and(shape > 0, scale > 0)))
|
|
1591
1790
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1592
|
-
return sample, key, err, params
|
|
1593
|
-
|
|
1791
|
+
return sample, key, err, params
|
|
1594
1792
|
return _jax_wrapped_distribution_gompertz
|
|
1595
1793
|
|
|
1596
|
-
def _jax_chisquare(self, expr,
|
|
1794
|
+
def _jax_chisquare(self, expr, aux):
|
|
1597
1795
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_CHISQUARE']
|
|
1598
1796
|
JaxRDDLCompiler._check_num_args(expr, 1)
|
|
1599
1797
|
|
|
1600
1798
|
arg_df, = expr.args
|
|
1601
|
-
jax_df = self._jax(arg_df,
|
|
1799
|
+
jax_df = self._jax(arg_df, aux)
|
|
1602
1800
|
|
|
1603
1801
|
# use the fact that ChiSquare(df) = Gamma(df/2, 2)
|
|
1604
|
-
def _jax_wrapped_distribution_chisquare(
|
|
1605
|
-
df, key, err1, params = jax_df(
|
|
1802
|
+
def _jax_wrapped_distribution_chisquare(fls, nfls, params, key):
|
|
1803
|
+
df, key, err1, params = jax_df(fls, nfls, params, key)
|
|
1606
1804
|
key, subkey = random.split(key)
|
|
1607
|
-
shape =
|
|
1608
|
-
|
|
1609
|
-
sample = 2.0 * Gamma
|
|
1805
|
+
shape = 0.5 * df
|
|
1806
|
+
sample = 2.0 * random.gamma(key=subkey, a=shape, dtype=self.REAL)
|
|
1610
1807
|
out_of_bounds = jnp.logical_not(jnp.all(df > 0))
|
|
1611
1808
|
err = err1 | (out_of_bounds * ERR)
|
|
1612
|
-
return sample, key, err, params
|
|
1613
|
-
|
|
1809
|
+
return sample, key, err, params
|
|
1614
1810
|
return _jax_wrapped_distribution_chisquare
|
|
1615
1811
|
|
|
1616
|
-
def _jax_kumaraswamy(self, expr,
|
|
1812
|
+
def _jax_kumaraswamy(self, expr, aux):
|
|
1617
1813
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_KUMARASWAMY']
|
|
1618
1814
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1619
1815
|
|
|
1620
1816
|
arg_a, arg_b = expr.args
|
|
1621
|
-
jax_a = self._jax(arg_a,
|
|
1622
|
-
jax_b = self._jax(arg_b,
|
|
1817
|
+
jax_a = self._jax(arg_a, aux)
|
|
1818
|
+
jax_b = self._jax(arg_b, aux)
|
|
1623
1819
|
|
|
1624
1820
|
# uses the reparameterization K(a, b) = (1 - (1 - U(0, 1))^{1/b})^{1/a}
|
|
1625
|
-
def _jax_wrapped_distribution_kumaraswamy(
|
|
1626
|
-
a, key, err1, params = jax_a(
|
|
1627
|
-
b, key, err2, params = jax_b(
|
|
1821
|
+
def _jax_wrapped_distribution_kumaraswamy(fls, nfls, params, key):
|
|
1822
|
+
a, key, err1, params = jax_a(fls, nfls, params, key)
|
|
1823
|
+
b, key, err2, params = jax_b(fls, nfls, params, key)
|
|
1628
1824
|
key, subkey = random.split(key)
|
|
1629
1825
|
U = random.uniform(key=subkey, shape=jnp.shape(a), dtype=self.REAL)
|
|
1630
1826
|
sample = jnp.power(1.0 - jnp.power(U, 1.0 / b), 1.0 / a)
|
|
1631
|
-
out_of_bounds = jnp.logical_not(jnp.all((a > 0
|
|
1827
|
+
out_of_bounds = jnp.logical_not(jnp.all(jnp.logical_and(a > 0, b > 0)))
|
|
1632
1828
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1633
|
-
return sample, key, err, params
|
|
1634
|
-
|
|
1829
|
+
return sample, key, err, params
|
|
1635
1830
|
return _jax_wrapped_distribution_kumaraswamy
|
|
1636
1831
|
|
|
1637
1832
|
# ===========================================================================
|
|
1638
1833
|
# random variables with enum support
|
|
1639
1834
|
# ===========================================================================
|
|
1640
1835
|
|
|
1641
|
-
|
|
1642
|
-
|
|
1836
|
+
@staticmethod
|
|
1837
|
+
def _jax_update_discrete_oob_error(err, prob):
|
|
1643
1838
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_DISCRETE']
|
|
1644
|
-
|
|
1645
|
-
|
|
1646
|
-
|
|
1647
|
-
|
|
1648
|
-
|
|
1649
|
-
|
|
1650
|
-
|
|
1651
|
-
|
|
1652
|
-
|
|
1653
|
-
|
|
1654
|
-
|
|
1655
|
-
|
|
1656
|
-
|
|
1657
|
-
|
|
1658
|
-
|
|
1659
|
-
|
|
1660
|
-
|
|
1661
|
-
error = NORMAL
|
|
1662
|
-
prob = [None] * len(jax_probs)
|
|
1663
|
-
for (i, jax_prob) in enumerate(jax_probs):
|
|
1664
|
-
prob[i], key, error_pdf, params = jax_prob(x, params, key)
|
|
1665
|
-
error |= error_pdf
|
|
1839
|
+
out_of_bounds = jnp.logical_not(jnp.logical_and(
|
|
1840
|
+
jnp.all(prob >= 0),
|
|
1841
|
+
jnp.allclose(jnp.sum(prob, axis=-1), 1.0)
|
|
1842
|
+
))
|
|
1843
|
+
error = err | (out_of_bounds * ERR)
|
|
1844
|
+
return error
|
|
1845
|
+
|
|
1846
|
+
def _jax_discrete_prob(self, jax_probs, unnormalized):
|
|
1847
|
+
def _jax_wrapped_calc_discrete_prob(fls, nfls, params, key):
|
|
1848
|
+
|
|
1849
|
+
# calculate probability expressions
|
|
1850
|
+
error = JaxRDDLCompiler.ERROR_CODES['NORMAL']
|
|
1851
|
+
prob = []
|
|
1852
|
+
for jax_prob in jax_probs:
|
|
1853
|
+
sample, key, error_pdf, params = jax_prob(fls, nfls, params, key)
|
|
1854
|
+
prob.append(sample)
|
|
1855
|
+
error = error | error_pdf
|
|
1666
1856
|
prob = jnp.stack(prob, axis=-1)
|
|
1667
|
-
|
|
1857
|
+
|
|
1858
|
+
# normalize them if required
|
|
1859
|
+
if unnormalized:
|
|
1668
1860
|
normalizer = jnp.sum(prob, axis=-1, keepdims=True)
|
|
1669
1861
|
prob = prob / normalizer
|
|
1670
|
-
|
|
1671
|
-
|
|
1672
|
-
|
|
1673
|
-
|
|
1674
|
-
|
|
1675
|
-
|
|
1676
|
-
|
|
1677
|
-
))
|
|
1678
|
-
error |= (out_of_bounds * ERR)
|
|
1679
|
-
return sample, key, error, params
|
|
1862
|
+
return prob, key, error, params
|
|
1863
|
+
return _jax_wrapped_calc_discrete_prob
|
|
1864
|
+
|
|
1865
|
+
def _jax_discrete(self, expr, aux, unnorm):
|
|
1866
|
+
ordered_args = self.traced.cached_sim_info(expr)
|
|
1867
|
+
jax_probs = [self._jax(arg, aux) for arg in ordered_args]
|
|
1868
|
+
prob_fn = self._jax_discrete_prob(jax_probs, unnorm)
|
|
1680
1869
|
|
|
1870
|
+
def _jax_wrapped_distribution_discrete(fls, nfls, params, key):
|
|
1871
|
+
prob, key, error, params = prob_fn(fls, nfls, params, key)
|
|
1872
|
+
key, subkey = random.split(key)
|
|
1873
|
+
sample = random.categorical(key=subkey, logits=jnp.log(prob), axis=-1)
|
|
1874
|
+
error = JaxRDDLCompiler._jax_update_discrete_oob_error(error, prob)
|
|
1875
|
+
return sample, key, error, params
|
|
1681
1876
|
return _jax_wrapped_distribution_discrete
|
|
1682
|
-
|
|
1683
|
-
|
|
1684
|
-
|
|
1877
|
+
|
|
1878
|
+
@staticmethod
|
|
1879
|
+
def _jax_discrete_pvar_prob(jax_probs, unnormalized):
|
|
1880
|
+
def _jax_wrapped_calc_discrete_prob(fls, nfls, params, key):
|
|
1881
|
+
prob, key, error, params = jax_probs(fls, nfls, params, key)
|
|
1882
|
+
if unnormalized:
|
|
1883
|
+
normalizer = jnp.sum(prob, axis=-1, keepdims=True)
|
|
1884
|
+
prob = prob / normalizer
|
|
1885
|
+
return prob, key, error, params
|
|
1886
|
+
return _jax_wrapped_calc_discrete_prob
|
|
1887
|
+
|
|
1888
|
+
def _jax_discrete_pvar(self, expr, aux, unnorm):
|
|
1685
1889
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1686
1890
|
_, args = expr.args
|
|
1687
1891
|
arg, = args
|
|
1688
|
-
|
|
1689
|
-
|
|
1690
|
-
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg):
|
|
1691
|
-
discrete_op = self.EXACT_OPS['sampling']['Discrete']
|
|
1692
|
-
else:
|
|
1693
|
-
discrete_op = self.OPS['sampling']['Discrete']
|
|
1694
|
-
jax_op = discrete_op(expr.id, init_params)
|
|
1695
|
-
|
|
1696
|
-
# compile probability function
|
|
1697
|
-
jax_probs = self._jax(arg, init_params)
|
|
1892
|
+
jax_probs = self._jax(arg, aux)
|
|
1893
|
+
prob_fn = self._jax_discrete_pvar_prob(jax_probs, unnorm)
|
|
1698
1894
|
|
|
1699
|
-
def _jax_wrapped_distribution_discrete_pvar(
|
|
1700
|
-
|
|
1701
|
-
# sample probabilities
|
|
1702
|
-
prob, key, error, params = jax_probs(x, params, key)
|
|
1703
|
-
if unnorm:
|
|
1704
|
-
normalizer = jnp.sum(prob, axis=-1, keepdims=True)
|
|
1705
|
-
prob = prob / normalizer
|
|
1706
|
-
|
|
1707
|
-
# dispatch to sampling subroutine
|
|
1895
|
+
def _jax_wrapped_distribution_discrete_pvar(fls, nfls, params, key):
|
|
1896
|
+
prob, key, error, params = prob_fn(fls, nfls, params, key)
|
|
1708
1897
|
key, subkey = random.split(key)
|
|
1709
|
-
sample
|
|
1710
|
-
|
|
1711
|
-
|
|
1712
|
-
jnp.allclose(jnp.sum(prob, axis=-1), 1.0)
|
|
1713
|
-
))
|
|
1714
|
-
error |= (out_of_bounds * ERR)
|
|
1715
|
-
return sample, key, error, params
|
|
1716
|
-
|
|
1898
|
+
sample = random.categorical(key=subkey, logits=jnp.log(prob), axis=-1)
|
|
1899
|
+
error = JaxRDDLCompiler._jax_update_discrete_oob_error(error, prob)
|
|
1900
|
+
return sample, key, error, params
|
|
1717
1901
|
return _jax_wrapped_distribution_discrete_pvar
|
|
1718
1902
|
|
|
1719
1903
|
# ===========================================================================
|
|
1720
1904
|
# random vectors
|
|
1721
1905
|
# ===========================================================================
|
|
1722
1906
|
|
|
1723
|
-
def _jax_random_vector(self, expr,
|
|
1907
|
+
def _jax_random_vector(self, expr, aux):
|
|
1724
1908
|
_, name = expr.etype
|
|
1725
1909
|
if name == 'MultivariateNormal':
|
|
1726
|
-
return self._jax_multivariate_normal(expr,
|
|
1910
|
+
return self._jax_multivariate_normal(expr, aux)
|
|
1727
1911
|
elif name == 'MultivariateStudent':
|
|
1728
|
-
return self._jax_multivariate_student(expr,
|
|
1912
|
+
return self._jax_multivariate_student(expr, aux)
|
|
1729
1913
|
elif name == 'Dirichlet':
|
|
1730
|
-
return self._jax_dirichlet(expr,
|
|
1914
|
+
return self._jax_dirichlet(expr, aux)
|
|
1731
1915
|
elif name == 'Multinomial':
|
|
1732
|
-
return self._jax_multinomial(expr,
|
|
1916
|
+
return self._jax_multinomial(expr, aux)
|
|
1733
1917
|
else:
|
|
1734
1918
|
raise RDDLNotImplementedError(
|
|
1735
1919
|
f'Distribution {name} is not supported.\n' + print_stack_trace(expr))
|
|
1736
1920
|
|
|
1737
|
-
def _jax_multivariate_normal(self, expr,
|
|
1921
|
+
def _jax_multivariate_normal(self, expr, aux):
|
|
1738
1922
|
_, args = expr.args
|
|
1739
1923
|
mean, cov = args
|
|
1740
|
-
jax_mean = self._jax(mean,
|
|
1741
|
-
jax_cov = self._jax(cov,
|
|
1924
|
+
jax_mean = self._jax(mean, aux)
|
|
1925
|
+
jax_cov = self._jax(cov, aux)
|
|
1742
1926
|
index, = self.traced.cached_sim_info(expr)
|
|
1743
1927
|
|
|
1744
1928
|
# reparameterization trick MN(m, LL') = LZ + m, where Z ~ Normal(0, 1)
|
|
1745
|
-
def _jax_wrapped_distribution_multivariate_normal(
|
|
1929
|
+
def _jax_wrapped_distribution_multivariate_normal(fls, nfls, params, key):
|
|
1746
1930
|
|
|
1747
1931
|
# sample the mean and covariance
|
|
1748
|
-
sample_mean, key, err1, params = jax_mean(
|
|
1749
|
-
sample_cov, key, err2, params = jax_cov(
|
|
1932
|
+
sample_mean, key, err1, params = jax_mean(fls, nfls, params, key)
|
|
1933
|
+
sample_cov, key, err2, params = jax_cov(fls, nfls, params, key)
|
|
1750
1934
|
|
|
1751
1935
|
# sample Normal(0, 1)
|
|
1752
1936
|
key, subkey = random.split(key)
|
|
1753
1937
|
Z = random.normal(
|
|
1754
|
-
key=subkey,
|
|
1755
|
-
shape=jnp.shape(sample_mean) + (1,),
|
|
1756
|
-
dtype=self.REAL
|
|
1757
|
-
)
|
|
1938
|
+
key=subkey, shape=jnp.shape(sample_mean) + (1,), dtype=self.REAL)
|
|
1758
1939
|
|
|
1759
1940
|
# compute L s.t. cov = L * L' and reparameterize
|
|
1760
1941
|
L = jnp.linalg.cholesky(sample_cov)
|
|
1761
1942
|
sample = jnp.matmul(L, Z)[..., 0] + sample_mean
|
|
1762
1943
|
sample = jnp.moveaxis(sample, source=-1, destination=index)
|
|
1763
1944
|
err = err1 | err2
|
|
1764
|
-
return sample, key, err, params
|
|
1765
|
-
|
|
1945
|
+
return sample, key, err, params
|
|
1766
1946
|
return _jax_wrapped_distribution_multivariate_normal
|
|
1767
1947
|
|
|
1768
|
-
def _jax_multivariate_student(self, expr,
|
|
1948
|
+
def _jax_multivariate_student(self, expr, aux):
|
|
1769
1949
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_MULTIVARIATE_STUDENT']
|
|
1770
1950
|
|
|
1771
1951
|
_, args = expr.args
|
|
1772
1952
|
mean, cov, df = args
|
|
1773
|
-
jax_mean = self._jax(mean,
|
|
1774
|
-
jax_cov = self._jax(cov,
|
|
1775
|
-
jax_df = self._jax(df,
|
|
1953
|
+
jax_mean = self._jax(mean, aux)
|
|
1954
|
+
jax_cov = self._jax(cov, aux)
|
|
1955
|
+
jax_df = self._jax(df, aux)
|
|
1776
1956
|
index, = self.traced.cached_sim_info(expr)
|
|
1777
1957
|
|
|
1778
1958
|
# reparameterization trick MN(m, LL') = LZ + m, where Z ~ StudentT(0, 1)
|
|
1779
|
-
def _jax_wrapped_distribution_multivariate_student(
|
|
1959
|
+
def _jax_wrapped_distribution_multivariate_student(fls, nfls, params, key):
|
|
1780
1960
|
|
|
1781
1961
|
# sample the mean and covariance and degrees of freedom
|
|
1782
|
-
sample_mean, key, err1, params = jax_mean(
|
|
1783
|
-
sample_cov, key, err2, params = jax_cov(
|
|
1784
|
-
sample_df, key, err3, params = jax_df(
|
|
1962
|
+
sample_mean, key, err1, params = jax_mean(fls, nfls, params, key)
|
|
1963
|
+
sample_cov, key, err2, params = jax_cov(fls, nfls, params, key)
|
|
1964
|
+
sample_df, key, err3, params = jax_df(fls, nfls, params, key)
|
|
1785
1965
|
out_of_bounds = jnp.logical_not(jnp.all(sample_df > 0))
|
|
1786
1966
|
|
|
1787
1967
|
# sample StudentT(0, 1, df) -- broadcast df to same shape as cov
|
|
@@ -1800,43 +1980,41 @@ class JaxRDDLCompiler:
|
|
|
1800
1980
|
sample = jnp.matmul(L, Z)[..., 0] + sample_mean
|
|
1801
1981
|
sample = jnp.moveaxis(sample, source=-1, destination=index)
|
|
1802
1982
|
error = err1 | err2 | err3 | (out_of_bounds * ERR)
|
|
1803
|
-
return sample, key, error, params
|
|
1804
|
-
|
|
1983
|
+
return sample, key, error, params
|
|
1805
1984
|
return _jax_wrapped_distribution_multivariate_student
|
|
1806
1985
|
|
|
1807
|
-
def _jax_dirichlet(self, expr,
|
|
1986
|
+
def _jax_dirichlet(self, expr, aux):
|
|
1808
1987
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_DIRICHLET']
|
|
1809
1988
|
|
|
1810
1989
|
_, args = expr.args
|
|
1811
1990
|
alpha, = args
|
|
1812
|
-
jax_alpha = self._jax(alpha,
|
|
1991
|
+
jax_alpha = self._jax(alpha, aux)
|
|
1813
1992
|
index, = self.traced.cached_sim_info(expr)
|
|
1814
1993
|
|
|
1815
1994
|
# sample Gamma(alpha_i, 1) and normalize across i
|
|
1816
|
-
def _jax_wrapped_distribution_dirichlet(
|
|
1817
|
-
alpha, key, error, params = jax_alpha(
|
|
1995
|
+
def _jax_wrapped_distribution_dirichlet(fls, nfls, params, key):
|
|
1996
|
+
alpha, key, error, params = jax_alpha(fls, nfls, params, key)
|
|
1818
1997
|
out_of_bounds = jnp.logical_not(jnp.all(alpha > 0))
|
|
1819
|
-
error
|
|
1998
|
+
error = error | (out_of_bounds * ERR)
|
|
1820
1999
|
key, subkey = random.split(key)
|
|
1821
|
-
|
|
1822
|
-
sample =
|
|
2000
|
+
gamma = random.gamma(key=subkey, a=alpha, dtype=self.REAL)
|
|
2001
|
+
sample = gamma / jnp.sum(gamma, axis=-1, keepdims=True)
|
|
1823
2002
|
sample = jnp.moveaxis(sample, source=-1, destination=index)
|
|
1824
|
-
return sample, key, error, params
|
|
1825
|
-
|
|
2003
|
+
return sample, key, error, params
|
|
1826
2004
|
return _jax_wrapped_distribution_dirichlet
|
|
1827
2005
|
|
|
1828
|
-
def _jax_multinomial(self, expr,
|
|
2006
|
+
def _jax_multinomial(self, expr, aux):
|
|
1829
2007
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_MULTINOMIAL']
|
|
1830
2008
|
|
|
1831
2009
|
_, args = expr.args
|
|
1832
2010
|
trials, prob = args
|
|
1833
|
-
jax_trials = self._jax(trials,
|
|
1834
|
-
jax_prob = self._jax(prob,
|
|
2011
|
+
jax_trials = self._jax(trials, aux)
|
|
2012
|
+
jax_prob = self._jax(prob, aux)
|
|
1835
2013
|
index, = self.traced.cached_sim_info(expr)
|
|
1836
2014
|
|
|
1837
|
-
def _jax_wrapped_distribution_multinomial(
|
|
1838
|
-
trials, key, err1, params = jax_trials(
|
|
1839
|
-
prob, key, err2, params = jax_prob(
|
|
2015
|
+
def _jax_wrapped_distribution_multinomial(fls, nfls, params, key):
|
|
2016
|
+
trials, key, err1, params = jax_trials(fls, nfls, params, key)
|
|
2017
|
+
prob, key, err2, params = jax_prob(fls, nfls, params, key)
|
|
1840
2018
|
trials = jnp.asarray(trials, dtype=self.REAL)
|
|
1841
2019
|
prob = jnp.asarray(prob, dtype=self.REAL)
|
|
1842
2020
|
key, subkey = random.split(key)
|
|
@@ -1844,70 +2022,66 @@ class JaxRDDLCompiler:
|
|
|
1844
2022
|
sample = jnp.asarray(dist.sample(seed=subkey), dtype=self.INT)
|
|
1845
2023
|
sample = jnp.moveaxis(sample, source=-1, destination=index)
|
|
1846
2024
|
out_of_bounds = jnp.logical_not(jnp.all(
|
|
1847
|
-
(
|
|
1848
|
-
|
|
1849
|
-
|
|
2025
|
+
jnp.logical_and(
|
|
2026
|
+
jnp.logical_and(prob >= 0, jnp.allclose(jnp.sum(prob, axis=-1), 1.)),
|
|
2027
|
+
trials >= 0
|
|
2028
|
+
)
|
|
1850
2029
|
))
|
|
1851
2030
|
error = err1 | err2 | (out_of_bounds * ERR)
|
|
1852
|
-
return sample, key, error, params
|
|
1853
|
-
|
|
2031
|
+
return sample, key, error, params
|
|
1854
2032
|
return _jax_wrapped_distribution_multinomial
|
|
1855
2033
|
|
|
1856
2034
|
# ===========================================================================
|
|
1857
2035
|
# matrix algebra
|
|
1858
2036
|
# ===========================================================================
|
|
1859
2037
|
|
|
1860
|
-
def _jax_matrix(self, expr,
|
|
2038
|
+
def _jax_matrix(self, expr, aux):
|
|
1861
2039
|
_, op = expr.etype
|
|
1862
2040
|
if op == 'det':
|
|
1863
|
-
return self._jax_matrix_det(expr,
|
|
2041
|
+
return self._jax_matrix_det(expr, aux)
|
|
1864
2042
|
elif op == 'inverse':
|
|
1865
|
-
return self._jax_matrix_inv(expr,
|
|
2043
|
+
return self._jax_matrix_inv(expr, aux, pseudo=False)
|
|
1866
2044
|
elif op == 'pinverse':
|
|
1867
|
-
return self._jax_matrix_inv(expr,
|
|
2045
|
+
return self._jax_matrix_inv(expr, aux, pseudo=True)
|
|
1868
2046
|
elif op == 'cholesky':
|
|
1869
|
-
return self._jax_matrix_cholesky(expr,
|
|
2047
|
+
return self._jax_matrix_cholesky(expr, aux)
|
|
1870
2048
|
else:
|
|
1871
2049
|
raise RDDLNotImplementedError(
|
|
1872
|
-
f'Matrix operation {op} is not supported.\n' +
|
|
1873
|
-
print_stack_trace(expr))
|
|
2050
|
+
f'Matrix operation {op} is not supported.\n' + print_stack_trace(expr))
|
|
1874
2051
|
|
|
1875
|
-
def _jax_matrix_det(self, expr,
|
|
1876
|
-
|
|
1877
|
-
jax_arg = self._jax(arg,
|
|
2052
|
+
def _jax_matrix_det(self, expr, aux):
|
|
2053
|
+
arg = expr.args[-1]
|
|
2054
|
+
jax_arg = self._jax(arg, aux)
|
|
1878
2055
|
|
|
1879
|
-
def _jax_wrapped_matrix_operation_det(
|
|
1880
|
-
sample_arg, key, error, params = jax_arg(
|
|
2056
|
+
def _jax_wrapped_matrix_operation_det(fls, nfls, params, key):
|
|
2057
|
+
sample_arg, key, error, params = jax_arg(fls, nfls, params, key)
|
|
1881
2058
|
sample = jnp.linalg.det(sample_arg)
|
|
1882
|
-
return sample, key, error, params
|
|
1883
|
-
|
|
2059
|
+
return sample, key, error, params
|
|
1884
2060
|
return _jax_wrapped_matrix_operation_det
|
|
1885
2061
|
|
|
1886
|
-
def _jax_matrix_inv(self, expr,
|
|
2062
|
+
def _jax_matrix_inv(self, expr, aux, pseudo):
|
|
1887
2063
|
_, arg = expr.args
|
|
1888
|
-
jax_arg = self._jax(arg,
|
|
2064
|
+
jax_arg = self._jax(arg, aux)
|
|
1889
2065
|
indices = self.traced.cached_sim_info(expr)
|
|
1890
2066
|
op = jnp.linalg.pinv if pseudo else jnp.linalg.inv
|
|
1891
2067
|
|
|
1892
|
-
def _jax_wrapped_matrix_operation_inv(
|
|
1893
|
-
sample_arg, key, error, params = jax_arg(
|
|
2068
|
+
def _jax_wrapped_matrix_operation_inv(fls, nfls, params, key):
|
|
2069
|
+
sample_arg, key, error, params = jax_arg(fls, nfls, params, key)
|
|
1894
2070
|
sample = op(sample_arg)
|
|
1895
2071
|
sample = jnp.moveaxis(sample, source=(-2, -1), destination=indices)
|
|
1896
|
-
return sample, key, error, params
|
|
1897
|
-
|
|
2072
|
+
return sample, key, error, params
|
|
1898
2073
|
return _jax_wrapped_matrix_operation_inv
|
|
1899
2074
|
|
|
1900
|
-
def _jax_matrix_cholesky(self, expr,
|
|
2075
|
+
def _jax_matrix_cholesky(self, expr, aux):
|
|
1901
2076
|
_, arg = expr.args
|
|
1902
|
-
jax_arg = self._jax(arg,
|
|
2077
|
+
jax_arg = self._jax(arg, aux)
|
|
1903
2078
|
indices = self.traced.cached_sim_info(expr)
|
|
1904
2079
|
op = jnp.linalg.cholesky
|
|
1905
2080
|
|
|
1906
|
-
def _jax_wrapped_matrix_operation_cholesky(
|
|
1907
|
-
sample_arg, key, error, params = jax_arg(
|
|
2081
|
+
def _jax_wrapped_matrix_operation_cholesky(fls, nfls, params, key):
|
|
2082
|
+
sample_arg, key, error, params = jax_arg(fls, nfls, params, key)
|
|
1908
2083
|
sample = op(sample_arg)
|
|
1909
2084
|
sample = jnp.moveaxis(sample, source=(-2, -1), destination=indices)
|
|
1910
|
-
return sample, key, error, params
|
|
1911
|
-
|
|
2085
|
+
return sample, key, error, params
|
|
1912
2086
|
return _jax_wrapped_matrix_operation_cholesky
|
|
1913
|
-
|
|
2087
|
+
|