pyRDDLGym-jax 0.5__py3-none-any.whl → 1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +463 -592
- pyRDDLGym_jax/core/logic.py +784 -544
- pyRDDLGym_jax/core/planner.py +329 -463
- pyRDDLGym_jax/core/simulator.py +7 -5
- pyRDDLGym_jax/core/tuning.py +379 -568
- pyRDDLGym_jax/core/visualization.py +1463 -0
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +4 -5
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
- pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
- pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
- pyRDDLGym_jax/examples/run_plan.py +4 -1
- pyRDDLGym_jax/examples/run_tune.py +40 -27
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/METADATA +161 -104
- pyRDDLGym_jax-1.0.dist-info/RECORD +45 -0
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
- pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
- pyRDDLGym_jax-0.5.dist-info/RECORD +0 -44
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/compiler.py
CHANGED
|
@@ -21,6 +21,8 @@ from pyRDDLGym.core.debug.exception import (
|
|
|
21
21
|
from pyRDDLGym.core.debug.logger import Logger
|
|
22
22
|
from pyRDDLGym.core.simulator import RDDLSimulatorPrecompiled
|
|
23
23
|
|
|
24
|
+
from pyRDDLGym_jax.core.logic import ExactLogic
|
|
25
|
+
|
|
24
26
|
# more robust approach - if user does not have this or broken try to continue
|
|
25
27
|
try:
|
|
26
28
|
from tensorflow_probability.substrates import jax as tfp
|
|
@@ -32,109 +34,6 @@ except Exception:
|
|
|
32
34
|
tfp = None
|
|
33
35
|
|
|
34
36
|
|
|
35
|
-
# ===========================================================================
|
|
36
|
-
# EXACT RDDL TO JAX COMPILATION RULES
|
|
37
|
-
# ===========================================================================
|
|
38
|
-
|
|
39
|
-
def _function_unary_exact_named(op, name):
|
|
40
|
-
|
|
41
|
-
def _jax_wrapped_unary_fn_exact(x, param):
|
|
42
|
-
return op(x)
|
|
43
|
-
|
|
44
|
-
return _jax_wrapped_unary_fn_exact
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
def _function_unary_exact_named_gamma():
|
|
48
|
-
|
|
49
|
-
def _jax_wrapped_unary_gamma_exact(x, param):
|
|
50
|
-
return jnp.exp(scipy.special.gammaln(x))
|
|
51
|
-
|
|
52
|
-
return _jax_wrapped_unary_gamma_exact
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
def _function_binary_exact_named(op, name):
|
|
56
|
-
|
|
57
|
-
def _jax_wrapped_binary_fn_exact(x, y, param):
|
|
58
|
-
return op(x, y)
|
|
59
|
-
|
|
60
|
-
return _jax_wrapped_binary_fn_exact
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
def _function_binary_exact_named_implies():
|
|
64
|
-
|
|
65
|
-
def _jax_wrapped_binary_implies_exact(x, y, param):
|
|
66
|
-
return jnp.logical_or(jnp.logical_not(x), y)
|
|
67
|
-
|
|
68
|
-
return _jax_wrapped_binary_implies_exact
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
def _function_binary_exact_named_log():
|
|
72
|
-
|
|
73
|
-
def _jax_wrapped_binary_log_exact(x, y, param):
|
|
74
|
-
return jnp.log(x) / jnp.log(y)
|
|
75
|
-
|
|
76
|
-
return _jax_wrapped_binary_log_exact
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
def _function_aggregation_exact_named(op, name):
|
|
80
|
-
|
|
81
|
-
def _jax_wrapped_aggregation_fn_exact(x, axis, param):
|
|
82
|
-
return op(x, axis=axis)
|
|
83
|
-
|
|
84
|
-
return _jax_wrapped_aggregation_fn_exact
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
def _function_if_exact_named():
|
|
88
|
-
|
|
89
|
-
def _jax_wrapped_if_exact(c, a, b, param):
|
|
90
|
-
return jnp.where(c > 0.5, a, b)
|
|
91
|
-
|
|
92
|
-
return _jax_wrapped_if_exact
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
def _function_switch_exact_named():
|
|
96
|
-
|
|
97
|
-
def _jax_wrapped_switch_exact(pred, cases, param):
|
|
98
|
-
pred = pred[jnp.newaxis, ...]
|
|
99
|
-
sample = jnp.take_along_axis(cases, pred, axis=0)
|
|
100
|
-
assert sample.shape[0] == 1
|
|
101
|
-
return sample[0, ...]
|
|
102
|
-
|
|
103
|
-
return _jax_wrapped_switch_exact
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
def _function_bernoulli_exact_named():
|
|
107
|
-
|
|
108
|
-
def _jax_wrapped_bernoulli_exact(key, prob, param):
|
|
109
|
-
return random.bernoulli(key, prob)
|
|
110
|
-
|
|
111
|
-
return _jax_wrapped_bernoulli_exact
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
def _function_discrete_exact_named():
|
|
115
|
-
|
|
116
|
-
def _jax_wrapped_discrete_exact(key, prob, param):
|
|
117
|
-
return random.categorical(key=key, logits=jnp.log(prob), axis=-1)
|
|
118
|
-
|
|
119
|
-
return _jax_wrapped_discrete_exact
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
def _function_poisson_exact_named():
|
|
123
|
-
|
|
124
|
-
def _jax_wrapped_poisson_exact(key, rate, param):
|
|
125
|
-
return random.poisson(key=key, lam=rate, dtype=jnp.int64)
|
|
126
|
-
|
|
127
|
-
return _jax_wrapped_poisson_exact
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
def _function_geometric_exact_named():
|
|
131
|
-
|
|
132
|
-
def _jax_wrapped_geometric_exact(key, prob, param):
|
|
133
|
-
return random.geometric(key=key, p=prob, dtype=jnp.int64)
|
|
134
|
-
|
|
135
|
-
return _jax_wrapped_geometric_exact
|
|
136
|
-
|
|
137
|
-
|
|
138
37
|
class JaxRDDLCompiler:
|
|
139
38
|
'''Compiles a RDDL AST representation into an equivalent JAX representation.
|
|
140
39
|
All operations are identical to their numpy equivalents.
|
|
@@ -145,88 +44,96 @@ class JaxRDDLCompiler:
|
|
|
145
44
|
# ===========================================================================
|
|
146
45
|
# EXACT RDDL TO JAX COMPILATION RULES BY DEFAULT
|
|
147
46
|
# ===========================================================================
|
|
148
|
-
|
|
149
|
-
EXACT_RDDL_TO_JAX_NEGATIVE = _function_unary_exact_named(jnp.negative, 'negative')
|
|
150
47
|
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
}
|
|
48
|
+
@staticmethod
|
|
49
|
+
def wrap_logic(func):
|
|
50
|
+
def exact_func(id, init_params):
|
|
51
|
+
return func
|
|
52
|
+
return exact_func
|
|
157
53
|
|
|
54
|
+
EXACT_RDDL_TO_JAX_NEGATIVE = wrap_logic(ExactLogic.exact_unary_function(jnp.negative))
|
|
55
|
+
EXACT_RDDL_TO_JAX_ARITHMETIC = {
|
|
56
|
+
'+': wrap_logic(ExactLogic.exact_binary_function(jnp.add)),
|
|
57
|
+
'-': wrap_logic(ExactLogic.exact_binary_function(jnp.subtract)),
|
|
58
|
+
'*': wrap_logic(ExactLogic.exact_binary_function(jnp.multiply)),
|
|
59
|
+
'/': wrap_logic(ExactLogic.exact_binary_function(jnp.divide))
|
|
60
|
+
}
|
|
61
|
+
|
|
158
62
|
EXACT_RDDL_TO_JAX_RELATIONAL = {
|
|
159
|
-
'>=':
|
|
160
|
-
'<=':
|
|
161
|
-
'<':
|
|
162
|
-
'>':
|
|
163
|
-
'==':
|
|
164
|
-
'~=':
|
|
165
|
-
}
|
|
166
|
-
|
|
63
|
+
'>=': wrap_logic(ExactLogic.exact_binary_function(jnp.greater_equal)),
|
|
64
|
+
'<=': wrap_logic(ExactLogic.exact_binary_function(jnp.less_equal)),
|
|
65
|
+
'<': wrap_logic(ExactLogic.exact_binary_function(jnp.less)),
|
|
66
|
+
'>': wrap_logic(ExactLogic.exact_binary_function(jnp.greater)),
|
|
67
|
+
'==': wrap_logic(ExactLogic.exact_binary_function(jnp.equal)),
|
|
68
|
+
'~=': wrap_logic(ExactLogic.exact_binary_function(jnp.not_equal))
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
EXACT_RDDL_TO_JAX_LOGICAL_NOT = wrap_logic(ExactLogic.exact_unary_function(jnp.logical_not))
|
|
167
72
|
EXACT_RDDL_TO_JAX_LOGICAL = {
|
|
168
|
-
'^':
|
|
169
|
-
'&':
|
|
170
|
-
'|':
|
|
171
|
-
'~':
|
|
172
|
-
'=>':
|
|
173
|
-
'<=>':
|
|
174
|
-
}
|
|
175
|
-
|
|
176
|
-
EXACT_RDDL_TO_JAX_LOGICAL_NOT = _function_unary_exact_named(jnp.logical_not, 'not')
|
|
73
|
+
'^': wrap_logic(ExactLogic.exact_binary_function(jnp.logical_and)),
|
|
74
|
+
'&': wrap_logic(ExactLogic.exact_binary_function(jnp.logical_and)),
|
|
75
|
+
'|': wrap_logic(ExactLogic.exact_binary_function(jnp.logical_or)),
|
|
76
|
+
'~': wrap_logic(ExactLogic.exact_binary_function(jnp.logical_xor)),
|
|
77
|
+
'=>': wrap_logic(ExactLogic.exact_binary_implies),
|
|
78
|
+
'<=>': wrap_logic(ExactLogic.exact_binary_function(jnp.equal))
|
|
79
|
+
}
|
|
177
80
|
|
|
178
81
|
EXACT_RDDL_TO_JAX_AGGREGATION = {
|
|
179
|
-
'sum':
|
|
180
|
-
'avg':
|
|
181
|
-
'prod':
|
|
182
|
-
'minimum':
|
|
183
|
-
'maximum':
|
|
184
|
-
'forall':
|
|
185
|
-
'exists':
|
|
186
|
-
'argmin':
|
|
187
|
-
'argmax':
|
|
82
|
+
'sum': wrap_logic(ExactLogic.exact_aggregation(jnp.sum)),
|
|
83
|
+
'avg': wrap_logic(ExactLogic.exact_aggregation(jnp.mean)),
|
|
84
|
+
'prod': wrap_logic(ExactLogic.exact_aggregation(jnp.prod)),
|
|
85
|
+
'minimum': wrap_logic(ExactLogic.exact_aggregation(jnp.min)),
|
|
86
|
+
'maximum': wrap_logic(ExactLogic.exact_aggregation(jnp.max)),
|
|
87
|
+
'forall': wrap_logic(ExactLogic.exact_aggregation(jnp.all)),
|
|
88
|
+
'exists': wrap_logic(ExactLogic.exact_aggregation(jnp.any)),
|
|
89
|
+
'argmin': wrap_logic(ExactLogic.exact_aggregation(jnp.argmin)),
|
|
90
|
+
'argmax': wrap_logic(ExactLogic.exact_aggregation(jnp.argmax))
|
|
188
91
|
}
|
|
189
92
|
|
|
190
93
|
EXACT_RDDL_TO_JAX_UNARY = {
|
|
191
|
-
'abs':
|
|
192
|
-
'sgn':
|
|
193
|
-
'round':
|
|
194
|
-
'floor':
|
|
195
|
-
'ceil':
|
|
196
|
-
'cos':
|
|
197
|
-
'sin':
|
|
198
|
-
'tan':
|
|
199
|
-
'acos':
|
|
200
|
-
'asin':
|
|
201
|
-
'atan':
|
|
202
|
-
'cosh':
|
|
203
|
-
'sinh':
|
|
204
|
-
'tanh':
|
|
205
|
-
'exp':
|
|
206
|
-
'ln':
|
|
207
|
-
'sqrt':
|
|
208
|
-
'lngamma':
|
|
209
|
-
'gamma':
|
|
210
|
-
}
|
|
211
|
-
|
|
94
|
+
'abs': wrap_logic(ExactLogic.exact_unary_function(jnp.abs)),
|
|
95
|
+
'sgn': wrap_logic(ExactLogic.exact_unary_function(jnp.sign)),
|
|
96
|
+
'round': wrap_logic(ExactLogic.exact_unary_function(jnp.round)),
|
|
97
|
+
'floor': wrap_logic(ExactLogic.exact_unary_function(jnp.floor)),
|
|
98
|
+
'ceil': wrap_logic(ExactLogic.exact_unary_function(jnp.ceil)),
|
|
99
|
+
'cos': wrap_logic(ExactLogic.exact_unary_function(jnp.cos)),
|
|
100
|
+
'sin': wrap_logic(ExactLogic.exact_unary_function(jnp.sin)),
|
|
101
|
+
'tan': wrap_logic(ExactLogic.exact_unary_function(jnp.tan)),
|
|
102
|
+
'acos': wrap_logic(ExactLogic.exact_unary_function(jnp.arccos)),
|
|
103
|
+
'asin': wrap_logic(ExactLogic.exact_unary_function(jnp.arcsin)),
|
|
104
|
+
'atan': wrap_logic(ExactLogic.exact_unary_function(jnp.arctan)),
|
|
105
|
+
'cosh': wrap_logic(ExactLogic.exact_unary_function(jnp.cosh)),
|
|
106
|
+
'sinh': wrap_logic(ExactLogic.exact_unary_function(jnp.sinh)),
|
|
107
|
+
'tanh': wrap_logic(ExactLogic.exact_unary_function(jnp.tanh)),
|
|
108
|
+
'exp': wrap_logic(ExactLogic.exact_unary_function(jnp.exp)),
|
|
109
|
+
'ln': wrap_logic(ExactLogic.exact_unary_function(jnp.log)),
|
|
110
|
+
'sqrt': wrap_logic(ExactLogic.exact_unary_function(jnp.sqrt)),
|
|
111
|
+
'lngamma': wrap_logic(ExactLogic.exact_unary_function(scipy.special.gammaln)),
|
|
112
|
+
'gamma': wrap_logic(ExactLogic.exact_unary_function(scipy.special.gamma))
|
|
113
|
+
}
|
|
114
|
+
|
|
115
|
+
@staticmethod
|
|
116
|
+
def _jax_wrapped_calc_log_exact(x, y, params):
|
|
117
|
+
return jnp.log(x) / jnp.log(y), params
|
|
118
|
+
|
|
212
119
|
EXACT_RDDL_TO_JAX_BINARY = {
|
|
213
|
-
'div':
|
|
214
|
-
'mod':
|
|
215
|
-
'fmod':
|
|
216
|
-
'min':
|
|
217
|
-
'max':
|
|
218
|
-
'pow':
|
|
219
|
-
'log':
|
|
220
|
-
'hypot':
|
|
120
|
+
'div': wrap_logic(ExactLogic.exact_binary_function(jnp.floor_divide)),
|
|
121
|
+
'mod': wrap_logic(ExactLogic.exact_binary_function(jnp.mod)),
|
|
122
|
+
'fmod': wrap_logic(ExactLogic.exact_binary_function(jnp.mod)),
|
|
123
|
+
'min': wrap_logic(ExactLogic.exact_binary_function(jnp.minimum)),
|
|
124
|
+
'max': wrap_logic(ExactLogic.exact_binary_function(jnp.maximum)),
|
|
125
|
+
'pow': wrap_logic(ExactLogic.exact_binary_function(jnp.power)),
|
|
126
|
+
'log': wrap_logic(_jax_wrapped_calc_log_exact),
|
|
127
|
+
'hypot': wrap_logic(ExactLogic.exact_binary_function(jnp.hypot)),
|
|
221
128
|
}
|
|
222
129
|
|
|
223
|
-
EXACT_RDDL_TO_JAX_IF =
|
|
224
|
-
EXACT_RDDL_TO_JAX_SWITCH =
|
|
130
|
+
EXACT_RDDL_TO_JAX_IF = wrap_logic(ExactLogic.exact_if_then_else)
|
|
131
|
+
EXACT_RDDL_TO_JAX_SWITCH = wrap_logic(ExactLogic.exact_switch)
|
|
225
132
|
|
|
226
|
-
EXACT_RDDL_TO_JAX_BERNOULLI =
|
|
227
|
-
EXACT_RDDL_TO_JAX_DISCRETE =
|
|
228
|
-
EXACT_RDDL_TO_JAX_POISSON =
|
|
229
|
-
EXACT_RDDL_TO_JAX_GEOMETRIC =
|
|
133
|
+
EXACT_RDDL_TO_JAX_BERNOULLI = wrap_logic(ExactLogic.exact_bernoulli)
|
|
134
|
+
EXACT_RDDL_TO_JAX_DISCRETE = wrap_logic(ExactLogic.exact_discrete)
|
|
135
|
+
EXACT_RDDL_TO_JAX_POISSON = wrap_logic(ExactLogic.exact_poisson)
|
|
136
|
+
EXACT_RDDL_TO_JAX_GEOMETRIC = wrap_logic(ExactLogic.exact_geometric)
|
|
230
137
|
|
|
231
138
|
def __init__(self, rddl: RDDLLiftedModel,
|
|
232
139
|
allow_synchronous_state: bool=True,
|
|
@@ -251,18 +158,17 @@ class JaxRDDLCompiler:
|
|
|
251
158
|
if use64bit:
|
|
252
159
|
self.INT = jnp.int64
|
|
253
160
|
self.REAL = jnp.float64
|
|
254
|
-
jax.config.update('jax_enable_x64', True)
|
|
255
161
|
else:
|
|
256
162
|
self.INT = jnp.int32
|
|
257
163
|
self.REAL = jnp.float32
|
|
258
|
-
|
|
164
|
+
jax.config.update('jax_enable_x64', use64bit)
|
|
259
165
|
self.ONE = jnp.asarray(1, dtype=self.INT)
|
|
260
166
|
self.JAX_TYPES = {
|
|
261
167
|
'int': self.INT,
|
|
262
168
|
'real': self.REAL,
|
|
263
169
|
'bool': bool
|
|
264
170
|
}
|
|
265
|
-
|
|
171
|
+
|
|
266
172
|
# compile initial values
|
|
267
173
|
initializer = RDDLValueInitializer(rddl)
|
|
268
174
|
self.init_values = initializer.initialize()
|
|
@@ -314,15 +220,13 @@ class JaxRDDLCompiler:
|
|
|
314
220
|
to the log file
|
|
315
221
|
:param heading: the heading to print before compilation information
|
|
316
222
|
'''
|
|
317
|
-
|
|
318
|
-
self.invariants = self._compile_constraints(self.rddl.invariants,
|
|
319
|
-
self.preconditions = self._compile_constraints(self.rddl.preconditions,
|
|
320
|
-
self.terminations = self._compile_constraints(self.rddl.terminations,
|
|
321
|
-
self.cpfs = self._compile_cpfs(
|
|
322
|
-
self.reward = self._compile_reward(
|
|
323
|
-
self.model_params =
|
|
324
|
-
for (key, (value, *_)) in info[0].items()}
|
|
325
|
-
self.relaxations = info[1]
|
|
223
|
+
init_params = {}
|
|
224
|
+
self.invariants = self._compile_constraints(self.rddl.invariants, init_params)
|
|
225
|
+
self.preconditions = self._compile_constraints(self.rddl.preconditions, init_params)
|
|
226
|
+
self.terminations = self._compile_constraints(self.rddl.terminations, init_params)
|
|
227
|
+
self.cpfs = self._compile_cpfs(init_params)
|
|
228
|
+
self.reward = self._compile_reward(init_params)
|
|
229
|
+
self.model_params = init_params
|
|
326
230
|
|
|
327
231
|
if log_jax_expr and self.logger is not None:
|
|
328
232
|
printed = self.print_jax()
|
|
@@ -332,7 +236,7 @@ class JaxRDDLCompiler:
|
|
|
332
236
|
printed_invariants = '\n\n'.join(v for v in printed['invariants'])
|
|
333
237
|
printed_preconds = '\n\n'.join(v for v in printed['preconditions'])
|
|
334
238
|
printed_terminals = '\n\n'.join(v for v in printed['terminations'])
|
|
335
|
-
printed_params = '\n'.join(f'{k}: {v}' for (k, v) in
|
|
239
|
+
printed_params = '\n'.join(f'{k}: {v}' for (k, v) in init_params.items())
|
|
336
240
|
message = (
|
|
337
241
|
f'[info] {heading}\n'
|
|
338
242
|
f'[info] compiled JAX CPFs:\n\n'
|
|
@@ -350,21 +254,21 @@ class JaxRDDLCompiler:
|
|
|
350
254
|
)
|
|
351
255
|
self.logger.log(message)
|
|
352
256
|
|
|
353
|
-
def _compile_constraints(self, constraints,
|
|
354
|
-
return [self._jax(expr,
|
|
257
|
+
def _compile_constraints(self, constraints, init_params):
|
|
258
|
+
return [self._jax(expr, init_params, dtype=bool) for expr in constraints]
|
|
355
259
|
|
|
356
|
-
def _compile_cpfs(self,
|
|
260
|
+
def _compile_cpfs(self, init_params):
|
|
357
261
|
jax_cpfs = {}
|
|
358
262
|
for cpfs in self.levels.values():
|
|
359
263
|
for cpf in cpfs:
|
|
360
264
|
_, expr = self.rddl.cpfs[cpf]
|
|
361
265
|
prange = self.rddl.variable_ranges[cpf]
|
|
362
266
|
dtype = self.JAX_TYPES.get(prange, self.INT)
|
|
363
|
-
jax_cpfs[cpf] = self._jax(expr,
|
|
267
|
+
jax_cpfs[cpf] = self._jax(expr, init_params, dtype=dtype)
|
|
364
268
|
return jax_cpfs
|
|
365
269
|
|
|
366
|
-
def _compile_reward(self,
|
|
367
|
-
return self._jax(self.rddl.reward,
|
|
270
|
+
def _compile_reward(self, init_params):
|
|
271
|
+
return self._jax(self.rddl.reward, init_params, dtype=self.REAL)
|
|
368
272
|
|
|
369
273
|
def _extract_inequality_constraint(self, expr):
|
|
370
274
|
result = []
|
|
@@ -392,7 +296,7 @@ class JaxRDDLCompiler:
|
|
|
392
296
|
result.extend(self._extract_equality_constraint(arg))
|
|
393
297
|
return result
|
|
394
298
|
|
|
395
|
-
def _jax_nonlinear_constraints(self):
|
|
299
|
+
def _jax_nonlinear_constraints(self, init_params):
|
|
396
300
|
rddl = self.rddl
|
|
397
301
|
|
|
398
302
|
# extract the non-box inequality constraints on actions
|
|
@@ -402,12 +306,12 @@ class JaxRDDLCompiler:
|
|
|
402
306
|
if not self.constraints.is_box_preconditions[i]]
|
|
403
307
|
|
|
404
308
|
# compile them to JAX and write as h(s, a) <= 0
|
|
405
|
-
|
|
309
|
+
jax_op = ExactLogic.exact_binary_function(jnp.subtract)
|
|
406
310
|
jax_inequalities = []
|
|
407
311
|
for (left, right) in inequalities:
|
|
408
|
-
jax_lhs = self._jax(left,
|
|
409
|
-
jax_rhs = self._jax(right,
|
|
410
|
-
jax_constr = self._jax_binary(jax_lhs, jax_rhs,
|
|
312
|
+
jax_lhs = self._jax(left, init_params)
|
|
313
|
+
jax_rhs = self._jax(right, init_params)
|
|
314
|
+
jax_constr = self._jax_binary(jax_lhs, jax_rhs, jax_op, at_least_int=True)
|
|
411
315
|
jax_inequalities.append(jax_constr)
|
|
412
316
|
|
|
413
317
|
# extract the non-box equality constraints on actions
|
|
@@ -419,15 +323,16 @@ class JaxRDDLCompiler:
|
|
|
419
323
|
# compile them to JAX and write as g(s, a) == 0
|
|
420
324
|
jax_equalities = []
|
|
421
325
|
for (left, right) in equalities:
|
|
422
|
-
jax_lhs = self._jax(left,
|
|
423
|
-
jax_rhs = self._jax(right,
|
|
424
|
-
jax_constr = self._jax_binary(jax_lhs, jax_rhs,
|
|
326
|
+
jax_lhs = self._jax(left, init_params)
|
|
327
|
+
jax_rhs = self._jax(right, init_params)
|
|
328
|
+
jax_constr = self._jax_binary(jax_lhs, jax_rhs, jax_op, at_least_int=True)
|
|
425
329
|
jax_equalities.append(jax_constr)
|
|
426
330
|
|
|
427
331
|
return jax_inequalities, jax_equalities
|
|
428
332
|
|
|
429
333
|
def compile_transition(self, check_constraints: bool=False,
|
|
430
|
-
constraint_func: bool=False
|
|
334
|
+
constraint_func: bool=False,
|
|
335
|
+
init_params_constr: Dict[str, Any]={}) -> Callable:
|
|
431
336
|
'''Compiles the current RDDL model into a JAX transition function that
|
|
432
337
|
samples the next state.
|
|
433
338
|
|
|
@@ -467,13 +372,12 @@ class JaxRDDLCompiler:
|
|
|
467
372
|
'''
|
|
468
373
|
NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
|
|
469
374
|
rddl = self.rddl
|
|
470
|
-
reward_fn, cpfs
|
|
471
|
-
|
|
472
|
-
self.preconditions, self.invariants, self.terminations
|
|
375
|
+
reward_fn, cpfs, preconds, invariants, terminals = \
|
|
376
|
+
self.reward, self.cpfs, self.preconditions, self.invariants, self.terminations
|
|
473
377
|
|
|
474
378
|
# compile constraint information
|
|
475
379
|
if constraint_func:
|
|
476
|
-
inequality_fns, equality_fns = self._jax_nonlinear_constraints()
|
|
380
|
+
inequality_fns, equality_fns = self._jax_nonlinear_constraints(init_params_constr)
|
|
477
381
|
else:
|
|
478
382
|
inequality_fns, equality_fns = None, None
|
|
479
383
|
|
|
@@ -486,7 +390,7 @@ class JaxRDDLCompiler:
|
|
|
486
390
|
precond_check = True
|
|
487
391
|
if check_constraints:
|
|
488
392
|
for precond in preconds:
|
|
489
|
-
sample, key, err = precond(subs, model_params, key)
|
|
393
|
+
sample, key, err, model_params = precond(subs, model_params, key)
|
|
490
394
|
precond_check = jnp.logical_and(precond_check, sample)
|
|
491
395
|
errors |= err
|
|
492
396
|
|
|
@@ -494,21 +398,21 @@ class JaxRDDLCompiler:
|
|
|
494
398
|
inequalities, equalities = [], []
|
|
495
399
|
if constraint_func:
|
|
496
400
|
for constraint in inequality_fns:
|
|
497
|
-
sample, key, err = constraint(subs, model_params, key)
|
|
401
|
+
sample, key, err, model_params = constraint(subs, model_params, key)
|
|
498
402
|
inequalities.append(sample)
|
|
499
403
|
errors |= err
|
|
500
404
|
for constraint in equality_fns:
|
|
501
|
-
sample, key, err = constraint(subs, model_params, key)
|
|
405
|
+
sample, key, err, model_params = constraint(subs, model_params, key)
|
|
502
406
|
equalities.append(sample)
|
|
503
407
|
errors |= err
|
|
504
408
|
|
|
505
409
|
# calculate CPFs in topological order
|
|
506
410
|
for (name, cpf) in cpfs.items():
|
|
507
|
-
subs[name], key, err = cpf(subs, model_params, key)
|
|
411
|
+
subs[name], key, err, model_params = cpf(subs, model_params, key)
|
|
508
412
|
errors |= err
|
|
509
413
|
|
|
510
414
|
# calculate the immediate reward
|
|
511
|
-
reward, key, err = reward_fn(subs, model_params, key)
|
|
415
|
+
reward, key, err, model_params = reward_fn(subs, model_params, key)
|
|
512
416
|
errors |= err
|
|
513
417
|
|
|
514
418
|
# calculate fluent values
|
|
@@ -523,7 +427,7 @@ class JaxRDDLCompiler:
|
|
|
523
427
|
invariant_check = True
|
|
524
428
|
if check_constraints:
|
|
525
429
|
for invariant in invariants:
|
|
526
|
-
sample, key, err = invariant(subs, model_params, key)
|
|
430
|
+
sample, key, err, model_params = invariant(subs, model_params, key)
|
|
527
431
|
invariant_check = jnp.logical_and(invariant_check, sample)
|
|
528
432
|
errors |= err
|
|
529
433
|
|
|
@@ -531,7 +435,7 @@ class JaxRDDLCompiler:
|
|
|
531
435
|
terminated_check = False
|
|
532
436
|
if check_constraints:
|
|
533
437
|
for terminal in terminals:
|
|
534
|
-
sample, key, err = terminal(subs, model_params, key)
|
|
438
|
+
sample, key, err, model_params = terminal(subs, model_params, key)
|
|
535
439
|
terminated_check = jnp.logical_or(terminated_check, sample)
|
|
536
440
|
errors |= err
|
|
537
441
|
|
|
@@ -548,7 +452,7 @@ class JaxRDDLCompiler:
|
|
|
548
452
|
log['inequalities'] = inequalities
|
|
549
453
|
log['equalities'] = equalities
|
|
550
454
|
|
|
551
|
-
return subs, log
|
|
455
|
+
return subs, log, model_params
|
|
552
456
|
|
|
553
457
|
return _jax_wrapped_single_step
|
|
554
458
|
|
|
@@ -556,7 +460,8 @@ class JaxRDDLCompiler:
|
|
|
556
460
|
n_steps: int,
|
|
557
461
|
n_batch: int,
|
|
558
462
|
check_constraints: bool=False,
|
|
559
|
-
constraint_func: bool=False
|
|
463
|
+
constraint_func: bool=False,
|
|
464
|
+
init_params_constr: Dict[str, Any]={}) -> Callable:
|
|
560
465
|
'''Compiles the current RDDL model into a JAX transition function that
|
|
561
466
|
samples trajectories with a fixed horizon from a policy.
|
|
562
467
|
|
|
@@ -569,7 +474,8 @@ class JaxRDDLCompiler:
|
|
|
569
474
|
|
|
570
475
|
The returned value of the returned function is:
|
|
571
476
|
- log is the dictionary of all trajectory information, including
|
|
572
|
-
constraints that were satisfied, errors, etc
|
|
477
|
+
constraints that were satisfied, errors, etc
|
|
478
|
+
- model_params is the final set of model parameters.
|
|
573
479
|
|
|
574
480
|
The arguments of the policy function is:
|
|
575
481
|
- key is the PRNG key (used by a stochastic policy)
|
|
@@ -589,7 +495,8 @@ class JaxRDDLCompiler:
|
|
|
589
495
|
in addition to the usual outputs
|
|
590
496
|
'''
|
|
591
497
|
rddl = self.rddl
|
|
592
|
-
jax_step_fn = self.compile_transition(
|
|
498
|
+
jax_step_fn = self.compile_transition(
|
|
499
|
+
check_constraints, constraint_func, init_params_constr)
|
|
593
500
|
|
|
594
501
|
# for POMDP only observ-fluents are assumed visible to the policy
|
|
595
502
|
if rddl.observ_fluents:
|
|
@@ -605,18 +512,19 @@ class JaxRDDLCompiler:
|
|
|
605
512
|
if var in observed_vars}
|
|
606
513
|
actions = policy(key, policy_params, hyperparams, step, states)
|
|
607
514
|
key, subkey = random.split(key)
|
|
608
|
-
|
|
609
|
-
return subs, log
|
|
515
|
+
return jax_step_fn(subkey, actions, subs, model_params)
|
|
610
516
|
|
|
611
517
|
# do a batched step update from the policy
|
|
518
|
+
# TODO: come up with a better way to reduce the model_param batch dim
|
|
612
519
|
def _jax_wrapped_batched_step_policy(carry, step):
|
|
613
520
|
key, policy_params, hyperparams, subs, model_params = carry
|
|
614
521
|
key, *subkeys = random.split(key, num=1 + n_batch)
|
|
615
522
|
keys = jnp.asarray(subkeys)
|
|
616
|
-
subs, log = jax.vmap(
|
|
523
|
+
subs, log, model_params = jax.vmap(
|
|
617
524
|
_jax_wrapped_single_step_policy,
|
|
618
525
|
in_axes=(0, None, None, None, 0, None)
|
|
619
526
|
)(keys, policy_params, hyperparams, step, subs, model_params)
|
|
527
|
+
model_params = jax.tree_map(lambda x: jnp.mean(x, axis=0), model_params)
|
|
620
528
|
carry = (key, policy_params, hyperparams, subs, model_params)
|
|
621
529
|
return carry, log
|
|
622
530
|
|
|
@@ -625,14 +533,15 @@ class JaxRDDLCompiler:
|
|
|
625
533
|
subs, model_params):
|
|
626
534
|
start = (key, policy_params, hyperparams, subs, model_params)
|
|
627
535
|
steps = jnp.arange(n_steps)
|
|
628
|
-
|
|
536
|
+
end, log = jax.lax.scan(_jax_wrapped_batched_step_policy, start, steps)
|
|
629
537
|
log = jax.tree_map(partial(jnp.swapaxes, axis1=0, axis2=1), log)
|
|
630
|
-
|
|
538
|
+
model_params = end[-1]
|
|
539
|
+
return log, model_params
|
|
631
540
|
|
|
632
541
|
return _jax_wrapped_batched_rollout
|
|
633
542
|
|
|
634
543
|
# ===========================================================================
|
|
635
|
-
# error checks
|
|
544
|
+
# error checks and prints
|
|
636
545
|
# ===========================================================================
|
|
637
546
|
|
|
638
547
|
def print_jax(self) -> Dict[str, Any]:
|
|
@@ -640,19 +549,30 @@ class JaxRDDLCompiler:
|
|
|
640
549
|
Jax compiled expressions from the RDDL file.
|
|
641
550
|
'''
|
|
642
551
|
subs = self.init_values
|
|
643
|
-
|
|
552
|
+
init_params = self.model_params
|
|
644
553
|
key = jax.random.PRNGKey(42)
|
|
645
|
-
printed = {
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
554
|
+
printed = {
|
|
555
|
+
'cpfs': {name: str(jax.make_jaxpr(expr)(subs, init_params, key))
|
|
556
|
+
for (name, expr) in self.cpfs.items()},
|
|
557
|
+
'reward': str(jax.make_jaxpr(self.reward)(subs, init_params, key)),
|
|
558
|
+
'invariants': [str(jax.make_jaxpr(expr)(subs, init_params, key))
|
|
559
|
+
for expr in self.invariants],
|
|
560
|
+
'preconditions': [str(jax.make_jaxpr(expr)(subs, init_params, key))
|
|
561
|
+
for expr in self.preconditions],
|
|
562
|
+
'terminations': [str(jax.make_jaxpr(expr)(subs, init_params, key))
|
|
563
|
+
for expr in self.terminations]
|
|
564
|
+
}
|
|
655
565
|
return printed
|
|
566
|
+
|
|
567
|
+
def model_parameter_info(self) -> Dict[str, Dict[str, Any]]:
|
|
568
|
+
'''Returns a dictionary of additional information about model
|
|
569
|
+
parameters.'''
|
|
570
|
+
result = {}
|
|
571
|
+
for (id, value) in self.model_params.items():
|
|
572
|
+
expr_id = int(str(id).split('_')[0])
|
|
573
|
+
expr = self.traced.lookup(expr_id)
|
|
574
|
+
result[id] = {'id': expr_id, 'rddl_op': ' '.join(expr.etype), 'init_value': value}
|
|
575
|
+
return result
|
|
656
576
|
|
|
657
577
|
@staticmethod
|
|
658
578
|
def _check_valid_op(expr, valid_ops):
|
|
@@ -661,8 +581,7 @@ class JaxRDDLCompiler:
|
|
|
661
581
|
valid_op_str = ','.join(valid_ops.keys())
|
|
662
582
|
raise RDDLNotImplementedError(
|
|
663
583
|
f'{etype} operator {op} is not supported: '
|
|
664
|
-
f'must be in {valid_op_str}.\n' +
|
|
665
|
-
print_stack_trace(expr))
|
|
584
|
+
f'must be in {valid_op_str}.\n' + print_stack_trace(expr))
|
|
666
585
|
|
|
667
586
|
@staticmethod
|
|
668
587
|
def _check_num_args(expr, required_args):
|
|
@@ -671,8 +590,7 @@ class JaxRDDLCompiler:
|
|
|
671
590
|
etype, op = expr.etype
|
|
672
591
|
raise RDDLInvalidNumberOfArgumentsError(
|
|
673
592
|
f'{etype} operator {op} requires {required_args} arguments, '
|
|
674
|
-
f'got {actual_args}.\n' +
|
|
675
|
-
print_stack_trace(expr))
|
|
593
|
+
f'got {actual_args}.\n' + print_stack_trace(expr))
|
|
676
594
|
|
|
677
595
|
ERROR_CODES = {
|
|
678
596
|
'NORMAL': 0,
|
|
@@ -749,73 +667,34 @@ class JaxRDDLCompiler:
|
|
|
749
667
|
messages = [JaxRDDLCompiler.INVERSE_ERROR_CODES[i] for i in codes]
|
|
750
668
|
return messages
|
|
751
669
|
|
|
752
|
-
# ===========================================================================
|
|
753
|
-
# handling of auxiliary data (e.g. model tuning parameters)
|
|
754
|
-
# ===========================================================================
|
|
755
|
-
|
|
756
|
-
def _unwrap(self, op, expr_id, info):
|
|
757
|
-
jax_op, name = op, None
|
|
758
|
-
model_params, relaxed_list = info
|
|
759
|
-
if isinstance(op, tuple):
|
|
760
|
-
jax_op, param = op
|
|
761
|
-
if param is not None:
|
|
762
|
-
tags, values = param
|
|
763
|
-
sep = JaxRDDLCompiler.MODEL_PARAM_TAG_SEPARATOR
|
|
764
|
-
if isinstance(tags, tuple):
|
|
765
|
-
name = sep.join(tags)
|
|
766
|
-
else:
|
|
767
|
-
name = str(tags)
|
|
768
|
-
name = f'{name}{sep}{expr_id}'
|
|
769
|
-
if name in model_params:
|
|
770
|
-
raise RuntimeError(
|
|
771
|
-
f'Internal error: model parameter {name} is already defined.')
|
|
772
|
-
model_params[name] = (values, tags, expr_id, jax_op.__name__)
|
|
773
|
-
relaxed_list.append((param, expr_id, jax_op.__name__))
|
|
774
|
-
return jax_op, name
|
|
775
|
-
|
|
776
|
-
def summarize_model_relaxations(self) -> str:
|
|
777
|
-
'''Returns a string of information about model relaxations in the
|
|
778
|
-
compiled model.'''
|
|
779
|
-
occurence_by_type = {}
|
|
780
|
-
for (_, expr_id, jax_op) in self.relaxations:
|
|
781
|
-
etype = self.traced.lookup(expr_id).etype
|
|
782
|
-
source = f'{etype[1]} ({etype[0]})'
|
|
783
|
-
sub = f'{source:<30} --> {jax_op}'
|
|
784
|
-
occurence_by_type[sub] = occurence_by_type.get(sub, 0) + 1
|
|
785
|
-
col = "{:<80} {:<10}\n"
|
|
786
|
-
table = col.format('Substitution', 'Count')
|
|
787
|
-
for (sub, occurs) in occurence_by_type.items():
|
|
788
|
-
table += col.format(sub, occurs)
|
|
789
|
-
return table
|
|
790
|
-
|
|
791
670
|
# ===========================================================================
|
|
792
671
|
# expression compilation
|
|
793
672
|
# ===========================================================================
|
|
794
673
|
|
|
795
|
-
def _jax(self, expr,
|
|
674
|
+
def _jax(self, expr, init_params, dtype=None):
|
|
796
675
|
etype, _ = expr.etype
|
|
797
676
|
if etype == 'constant':
|
|
798
|
-
jax_expr = self._jax_constant(expr,
|
|
677
|
+
jax_expr = self._jax_constant(expr, init_params)
|
|
799
678
|
elif etype == 'pvar':
|
|
800
|
-
jax_expr = self._jax_pvar(expr,
|
|
679
|
+
jax_expr = self._jax_pvar(expr, init_params)
|
|
801
680
|
elif etype == 'arithmetic':
|
|
802
|
-
jax_expr = self._jax_arithmetic(expr,
|
|
681
|
+
jax_expr = self._jax_arithmetic(expr, init_params)
|
|
803
682
|
elif etype == 'relational':
|
|
804
|
-
jax_expr = self._jax_relational(expr,
|
|
683
|
+
jax_expr = self._jax_relational(expr, init_params)
|
|
805
684
|
elif etype == 'boolean':
|
|
806
|
-
jax_expr = self._jax_logical(expr,
|
|
685
|
+
jax_expr = self._jax_logical(expr, init_params)
|
|
807
686
|
elif etype == 'aggregation':
|
|
808
|
-
jax_expr = self._jax_aggregation(expr,
|
|
687
|
+
jax_expr = self._jax_aggregation(expr, init_params)
|
|
809
688
|
elif etype == 'func':
|
|
810
|
-
jax_expr = self._jax_functional(expr,
|
|
689
|
+
jax_expr = self._jax_functional(expr, init_params)
|
|
811
690
|
elif etype == 'control':
|
|
812
|
-
jax_expr = self._jax_control(expr,
|
|
691
|
+
jax_expr = self._jax_control(expr, init_params)
|
|
813
692
|
elif etype == 'randomvar':
|
|
814
|
-
jax_expr = self._jax_random(expr,
|
|
693
|
+
jax_expr = self._jax_random(expr, init_params)
|
|
815
694
|
elif etype == 'randomvector':
|
|
816
|
-
jax_expr = self._jax_random_vector(expr,
|
|
695
|
+
jax_expr = self._jax_random_vector(expr, init_params)
|
|
817
696
|
elif etype == 'matrix':
|
|
818
|
-
jax_expr = self._jax_matrix(expr,
|
|
697
|
+
jax_expr = self._jax_matrix(expr, init_params)
|
|
819
698
|
else:
|
|
820
699
|
raise RDDLNotImplementedError(
|
|
821
700
|
f'Internal error: expression type {expr} is not supported.\n' +
|
|
@@ -831,13 +710,14 @@ class JaxRDDLCompiler:
|
|
|
831
710
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
|
|
832
711
|
|
|
833
712
|
def _jax_wrapped_cast(x, params, key):
|
|
834
|
-
val, key, err = jax_expr(x, params, key)
|
|
713
|
+
val, key, err, params = jax_expr(x, params, key)
|
|
835
714
|
sample = jnp.asarray(val, dtype=dtype)
|
|
836
715
|
invalid_cast = jnp.logical_and(
|
|
837
716
|
jnp.logical_not(jnp.can_cast(val, dtype)),
|
|
838
|
-
jnp.any(sample != val)
|
|
717
|
+
jnp.any(sample != val)
|
|
718
|
+
)
|
|
839
719
|
err |= (invalid_cast * ERR)
|
|
840
|
-
return sample, key, err
|
|
720
|
+
return sample, key, err, params
|
|
841
721
|
|
|
842
722
|
return _jax_wrapped_cast
|
|
843
723
|
|
|
@@ -856,13 +736,13 @@ class JaxRDDLCompiler:
|
|
|
856
736
|
# leaves
|
|
857
737
|
# ===========================================================================
|
|
858
738
|
|
|
859
|
-
def _jax_constant(self, expr,
|
|
739
|
+
def _jax_constant(self, expr, init_params):
|
|
860
740
|
NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
|
|
861
741
|
cached_value = self.traced.cached_sim_info(expr)
|
|
862
742
|
|
|
863
743
|
def _jax_wrapped_constant(x, params, key):
|
|
864
744
|
sample = jnp.asarray(cached_value, dtype=self._fix_dtype(cached_value))
|
|
865
|
-
return sample, key, NORMAL
|
|
745
|
+
return sample, key, NORMAL, params
|
|
866
746
|
|
|
867
747
|
return _jax_wrapped_constant
|
|
868
748
|
|
|
@@ -870,11 +750,11 @@ class JaxRDDLCompiler:
|
|
|
870
750
|
NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
|
|
871
751
|
|
|
872
752
|
def _jax_wrapped_pvar_slice(x, params, key):
|
|
873
|
-
return _slice, key, NORMAL
|
|
753
|
+
return _slice, key, NORMAL, params
|
|
874
754
|
|
|
875
755
|
return _jax_wrapped_pvar_slice
|
|
876
756
|
|
|
877
|
-
def _jax_pvar(self, expr,
|
|
757
|
+
def _jax_pvar(self, expr, init_params):
|
|
878
758
|
NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
|
|
879
759
|
var, pvars = expr.args
|
|
880
760
|
is_value, cached_info = self.traced.cached_sim_info(expr)
|
|
@@ -886,7 +766,7 @@ class JaxRDDLCompiler:
|
|
|
886
766
|
|
|
887
767
|
def _jax_wrapped_object(x, params, key):
|
|
888
768
|
sample = jnp.asarray(cached_value, dtype=self._fix_dtype(cached_value))
|
|
889
|
-
return sample, key, NORMAL
|
|
769
|
+
return sample, key, NORMAL, params
|
|
890
770
|
|
|
891
771
|
return _jax_wrapped_object
|
|
892
772
|
|
|
@@ -896,7 +776,7 @@ class JaxRDDLCompiler:
|
|
|
896
776
|
def _jax_wrapped_pvar_scalar(x, params, key):
|
|
897
777
|
value = x[var]
|
|
898
778
|
sample = jnp.asarray(value, dtype=self._fix_dtype(value))
|
|
899
|
-
return sample, key, NORMAL
|
|
779
|
+
return sample, key, NORMAL, params
|
|
900
780
|
|
|
901
781
|
return _jax_wrapped_pvar_scalar
|
|
902
782
|
|
|
@@ -907,7 +787,7 @@ class JaxRDDLCompiler:
|
|
|
907
787
|
# compile nested expressions
|
|
908
788
|
if slices and op_code == RDDLObjectsTracer.NUMPY_OP_CODE.NESTED_SLICE:
|
|
909
789
|
|
|
910
|
-
jax_nested_expr = [(self._jax(arg,
|
|
790
|
+
jax_nested_expr = [(self._jax(arg, init_params)
|
|
911
791
|
if _slice is None
|
|
912
792
|
else self._jax_pvar_slice(_slice))
|
|
913
793
|
for (arg, _slice) in zip(pvars, slices)]
|
|
@@ -918,11 +798,11 @@ class JaxRDDLCompiler:
|
|
|
918
798
|
sample = jnp.asarray(value, dtype=self._fix_dtype(value))
|
|
919
799
|
new_slices = [None] * len(jax_nested_expr)
|
|
920
800
|
for (i, jax_expr) in enumerate(jax_nested_expr):
|
|
921
|
-
new_slices[i], key, err = jax_expr(x, params, key)
|
|
801
|
+
new_slices[i], key, err, params = jax_expr(x, params, key)
|
|
922
802
|
error |= err
|
|
923
803
|
new_slices = tuple(new_slices)
|
|
924
804
|
sample = sample[new_slices]
|
|
925
|
-
return sample, key, error
|
|
805
|
+
return sample, key, error, params
|
|
926
806
|
|
|
927
807
|
return _jax_wrapped_pvar_tensor_nested
|
|
928
808
|
|
|
@@ -941,7 +821,7 @@ class JaxRDDLCompiler:
|
|
|
941
821
|
sample = jnp.einsum(sample, *op_args)
|
|
942
822
|
elif op_code == RDDLObjectsTracer.NUMPY_OP_CODE.TRANSPOSE:
|
|
943
823
|
sample = jnp.transpose(sample, axes=op_args)
|
|
944
|
-
return sample, key, NORMAL
|
|
824
|
+
return sample, key, NORMAL, params
|
|
945
825
|
|
|
946
826
|
return _jax_wrapped_pvar_tensor_non_nested
|
|
947
827
|
|
|
@@ -949,46 +829,43 @@ class JaxRDDLCompiler:
|
|
|
949
829
|
# mathematical
|
|
950
830
|
# ===========================================================================
|
|
951
831
|
|
|
952
|
-
def _jax_unary(self, jax_expr, jax_op,
|
|
953
|
-
at_least_int=False, check_dtype=None):
|
|
832
|
+
def _jax_unary(self, jax_expr, jax_op, at_least_int=False, check_dtype=None):
|
|
954
833
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
|
|
955
834
|
|
|
956
835
|
def _jax_wrapped_unary_op(x, params, key):
|
|
957
|
-
sample, key, err = jax_expr(x, params, key)
|
|
836
|
+
sample, key, err, params = jax_expr(x, params, key)
|
|
958
837
|
if at_least_int:
|
|
959
838
|
sample = self.ONE * sample
|
|
960
|
-
|
|
961
|
-
sample = jax_op(sample, param)
|
|
839
|
+
sample, params = jax_op(sample, params)
|
|
962
840
|
if check_dtype is not None:
|
|
963
841
|
invalid_cast = jnp.logical_not(jnp.can_cast(sample, check_dtype))
|
|
964
842
|
err |= (invalid_cast * ERR)
|
|
965
|
-
return sample, key, err
|
|
843
|
+
return sample, key, err, params
|
|
966
844
|
|
|
967
845
|
return _jax_wrapped_unary_op
|
|
968
846
|
|
|
969
|
-
def _jax_binary(self, jax_lhs, jax_rhs, jax_op,
|
|
970
|
-
at_least_int=False, check_dtype=None):
|
|
847
|
+
def _jax_binary(self, jax_lhs, jax_rhs, jax_op, at_least_int=False, check_dtype=None):
|
|
971
848
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
|
|
972
849
|
|
|
973
850
|
def _jax_wrapped_binary_op(x, params, key):
|
|
974
|
-
sample1, key, err1 = jax_lhs(x, params, key)
|
|
975
|
-
sample2, key, err2 = jax_rhs(x, params, key)
|
|
851
|
+
sample1, key, err1, params = jax_lhs(x, params, key)
|
|
852
|
+
sample2, key, err2, params = jax_rhs(x, params, key)
|
|
976
853
|
if at_least_int:
|
|
977
854
|
sample1 = self.ONE * sample1
|
|
978
855
|
sample2 = self.ONE * sample2
|
|
979
|
-
|
|
980
|
-
sample = jax_op(sample1, sample2, param)
|
|
856
|
+
sample, params = jax_op(sample1, sample2, params)
|
|
981
857
|
err = err1 | err2
|
|
982
858
|
if check_dtype is not None:
|
|
983
859
|
invalid_cast = jnp.logical_not(jnp.logical_and(
|
|
984
860
|
jnp.can_cast(sample1, check_dtype),
|
|
985
|
-
jnp.can_cast(sample2, check_dtype))
|
|
861
|
+
jnp.can_cast(sample2, check_dtype))
|
|
862
|
+
)
|
|
986
863
|
err |= (invalid_cast * ERR)
|
|
987
|
-
return sample, key, err
|
|
864
|
+
return sample, key, err, params
|
|
988
865
|
|
|
989
866
|
return _jax_wrapped_binary_op
|
|
990
867
|
|
|
991
|
-
def _jax_arithmetic(self, expr,
|
|
868
|
+
def _jax_arithmetic(self, expr, init_params):
|
|
992
869
|
_, op = expr.etype
|
|
993
870
|
|
|
994
871
|
# if expression is non-fluent, always use the exact operation
|
|
@@ -1005,22 +882,21 @@ class JaxRDDLCompiler:
|
|
|
1005
882
|
n = len(args)
|
|
1006
883
|
if n == 1 and op == '-':
|
|
1007
884
|
arg, = args
|
|
1008
|
-
jax_expr = self._jax(arg,
|
|
1009
|
-
jax_op
|
|
1010
|
-
return self._jax_unary(jax_expr, jax_op,
|
|
885
|
+
jax_expr = self._jax(arg, init_params)
|
|
886
|
+
jax_op = negative_op(expr.id, init_params)
|
|
887
|
+
return self._jax_unary(jax_expr, jax_op, at_least_int=True)
|
|
1011
888
|
|
|
1012
889
|
elif n == 2 or (n >= 2 and op in {'*', '+'}):
|
|
1013
|
-
jax_exprs = [self._jax(arg,
|
|
1014
|
-
jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
|
|
890
|
+
jax_exprs = [self._jax(arg, init_params) for arg in args]
|
|
1015
891
|
result = jax_exprs[0]
|
|
1016
|
-
for jax_rhs in jax_exprs[1:]:
|
|
1017
|
-
|
|
1018
|
-
|
|
892
|
+
for i, jax_rhs in enumerate(jax_exprs[1:]):
|
|
893
|
+
jax_op = valid_ops[op](f'{expr.id}_{op}{i}', init_params)
|
|
894
|
+
result = self._jax_binary(result, jax_rhs, jax_op, at_least_int=True)
|
|
1019
895
|
return result
|
|
1020
896
|
|
|
1021
897
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1022
898
|
|
|
1023
|
-
def _jax_relational(self, expr,
|
|
899
|
+
def _jax_relational(self, expr, init_params):
|
|
1024
900
|
_, op = expr.etype
|
|
1025
901
|
|
|
1026
902
|
# if expression is non-fluent, always use the exact operation
|
|
@@ -1029,17 +905,16 @@ class JaxRDDLCompiler:
|
|
|
1029
905
|
else:
|
|
1030
906
|
valid_ops = self.RELATIONAL_OPS
|
|
1031
907
|
JaxRDDLCompiler._check_valid_op(expr, valid_ops)
|
|
1032
|
-
jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
|
|
1033
908
|
|
|
1034
909
|
# recursively compile arguments
|
|
1035
910
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1036
911
|
lhs, rhs = expr.args
|
|
1037
|
-
jax_lhs = self._jax(lhs,
|
|
1038
|
-
jax_rhs = self._jax(rhs,
|
|
1039
|
-
|
|
1040
|
-
|
|
912
|
+
jax_lhs = self._jax(lhs, init_params)
|
|
913
|
+
jax_rhs = self._jax(rhs, init_params)
|
|
914
|
+
jax_op = valid_ops[op](expr.id, init_params)
|
|
915
|
+
return self._jax_binary(jax_lhs, jax_rhs, jax_op, at_least_int=True)
|
|
1041
916
|
|
|
1042
|
-
def _jax_logical(self, expr,
|
|
917
|
+
def _jax_logical(self, expr, init_params):
|
|
1043
918
|
_, op = expr.etype
|
|
1044
919
|
|
|
1045
920
|
# if expression is non-fluent, always use the exact operation
|
|
@@ -1056,22 +931,21 @@ class JaxRDDLCompiler:
|
|
|
1056
931
|
n = len(args)
|
|
1057
932
|
if n == 1 and op == '~':
|
|
1058
933
|
arg, = args
|
|
1059
|
-
jax_expr = self._jax(arg,
|
|
1060
|
-
jax_op
|
|
1061
|
-
return self._jax_unary(jax_expr, jax_op,
|
|
934
|
+
jax_expr = self._jax(arg, init_params)
|
|
935
|
+
jax_op = logical_not_op(expr.id, init_params)
|
|
936
|
+
return self._jax_unary(jax_expr, jax_op, check_dtype=bool)
|
|
1062
937
|
|
|
1063
938
|
elif n == 2 or (n >= 2 and op in {'^', '&', '|'}):
|
|
1064
|
-
jax_exprs = [self._jax(arg,
|
|
1065
|
-
jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
|
|
939
|
+
jax_exprs = [self._jax(arg, init_params) for arg in args]
|
|
1066
940
|
result = jax_exprs[0]
|
|
1067
|
-
for jax_rhs in jax_exprs[1:]:
|
|
1068
|
-
|
|
1069
|
-
|
|
941
|
+
for i, jax_rhs in enumerate(jax_exprs[1:]):
|
|
942
|
+
jax_op = valid_ops[op](f'{expr.id}_{op}{i}', init_params)
|
|
943
|
+
result = self._jax_binary(result, jax_rhs, jax_op, check_dtype=bool)
|
|
1070
944
|
return result
|
|
1071
945
|
|
|
1072
946
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1073
947
|
|
|
1074
|
-
def _jax_aggregation(self, expr,
|
|
948
|
+
def _jax_aggregation(self, expr, init_params):
|
|
1075
949
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
|
|
1076
950
|
_, op = expr.etype
|
|
1077
951
|
|
|
@@ -1081,28 +955,27 @@ class JaxRDDLCompiler:
|
|
|
1081
955
|
else:
|
|
1082
956
|
valid_ops = self.AGGREGATION_OPS
|
|
1083
957
|
JaxRDDLCompiler._check_valid_op(expr, valid_ops)
|
|
1084
|
-
|
|
958
|
+
is_floating = op not in self.AGGREGATION_BOOL
|
|
1085
959
|
|
|
1086
960
|
# recursively compile arguments
|
|
1087
|
-
is_floating = op not in self.AGGREGATION_BOOL
|
|
1088
961
|
* _, arg = expr.args
|
|
1089
962
|
_, axes = self.traced.cached_sim_info(expr)
|
|
1090
|
-
jax_expr = self._jax(arg,
|
|
963
|
+
jax_expr = self._jax(arg, init_params)
|
|
964
|
+
jax_op = valid_ops[op](expr.id, init_params)
|
|
1091
965
|
|
|
1092
966
|
def _jax_wrapped_aggregation(x, params, key):
|
|
1093
|
-
sample, key, err = jax_expr(x, params, key)
|
|
967
|
+
sample, key, err, params = jax_expr(x, params, key)
|
|
1094
968
|
if is_floating:
|
|
1095
969
|
sample = self.ONE * sample
|
|
1096
970
|
else:
|
|
1097
971
|
invalid_cast = jnp.logical_not(jnp.can_cast(sample, bool))
|
|
1098
972
|
err |= (invalid_cast * ERR)
|
|
1099
|
-
|
|
1100
|
-
|
|
1101
|
-
return sample, key, err
|
|
973
|
+
sample, params = jax_op(sample, axis=axes, params=params)
|
|
974
|
+
return sample, key, err, params
|
|
1102
975
|
|
|
1103
976
|
return _jax_wrapped_aggregation
|
|
1104
977
|
|
|
1105
|
-
def _jax_functional(self, expr,
|
|
978
|
+
def _jax_functional(self, expr, init_params):
|
|
1106
979
|
_, op = expr.etype
|
|
1107
980
|
|
|
1108
981
|
# if expression is non-fluent, always use the exact operation
|
|
@@ -1117,39 +990,36 @@ class JaxRDDLCompiler:
|
|
|
1117
990
|
if op in unary_ops:
|
|
1118
991
|
JaxRDDLCompiler._check_num_args(expr, 1)
|
|
1119
992
|
arg, = expr.args
|
|
1120
|
-
jax_expr = self._jax(arg,
|
|
1121
|
-
jax_op
|
|
1122
|
-
return self._jax_unary(jax_expr, jax_op,
|
|
993
|
+
jax_expr = self._jax(arg, init_params)
|
|
994
|
+
jax_op = unary_ops[op](expr.id, init_params)
|
|
995
|
+
return self._jax_unary(jax_expr, jax_op, at_least_int=True)
|
|
1123
996
|
|
|
1124
997
|
elif op in binary_ops:
|
|
1125
998
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1126
999
|
lhs, rhs = expr.args
|
|
1127
|
-
jax_lhs = self._jax(lhs,
|
|
1128
|
-
jax_rhs = self._jax(rhs,
|
|
1129
|
-
jax_op
|
|
1130
|
-
return self._jax_binary(
|
|
1131
|
-
jax_lhs, jax_rhs, jax_op, jax_param, at_least_int=True)
|
|
1000
|
+
jax_lhs = self._jax(lhs, init_params)
|
|
1001
|
+
jax_rhs = self._jax(rhs, init_params)
|
|
1002
|
+
jax_op = binary_ops[op](expr.id, init_params)
|
|
1003
|
+
return self._jax_binary(jax_lhs, jax_rhs, jax_op, at_least_int=True)
|
|
1132
1004
|
|
|
1133
1005
|
raise RDDLNotImplementedError(
|
|
1134
|
-
f'Function {op} is not supported.\n' +
|
|
1135
|
-
print_stack_trace(expr))
|
|
1006
|
+
f'Function {op} is not supported.\n' + print_stack_trace(expr))
|
|
1136
1007
|
|
|
1137
1008
|
# ===========================================================================
|
|
1138
1009
|
# control flow
|
|
1139
1010
|
# ===========================================================================
|
|
1140
1011
|
|
|
1141
|
-
def _jax_control(self, expr,
|
|
1012
|
+
def _jax_control(self, expr, init_params):
|
|
1142
1013
|
_, op = expr.etype
|
|
1143
1014
|
if op == 'if':
|
|
1144
|
-
return self._jax_if(expr,
|
|
1015
|
+
return self._jax_if(expr, init_params)
|
|
1145
1016
|
elif op == 'switch':
|
|
1146
|
-
return self._jax_switch(expr,
|
|
1017
|
+
return self._jax_switch(expr, init_params)
|
|
1147
1018
|
|
|
1148
1019
|
raise RDDLNotImplementedError(
|
|
1149
|
-
f'Control operator {op} is not supported.\n' +
|
|
1150
|
-
print_stack_trace(expr))
|
|
1020
|
+
f'Control operator {op} is not supported.\n' + print_stack_trace(expr))
|
|
1151
1021
|
|
|
1152
|
-
def _jax_if(self, expr,
|
|
1022
|
+
def _jax_if(self, expr, init_params):
|
|
1153
1023
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
|
|
1154
1024
|
JaxRDDLCompiler._check_num_args(expr, 3)
|
|
1155
1025
|
pred, if_true, if_false = expr.args
|
|
@@ -1159,27 +1029,26 @@ class JaxRDDLCompiler:
|
|
|
1159
1029
|
if_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_IF
|
|
1160
1030
|
else:
|
|
1161
1031
|
if_op = self.IF_HELPER
|
|
1162
|
-
|
|
1032
|
+
jax_op = if_op(expr.id, init_params)
|
|
1163
1033
|
|
|
1164
1034
|
# recursively compile arguments
|
|
1165
|
-
jax_pred = self._jax(pred,
|
|
1166
|
-
jax_true = self._jax(if_true,
|
|
1167
|
-
jax_false = self._jax(if_false,
|
|
1035
|
+
jax_pred = self._jax(pred, init_params)
|
|
1036
|
+
jax_true = self._jax(if_true, init_params)
|
|
1037
|
+
jax_false = self._jax(if_false, init_params)
|
|
1168
1038
|
|
|
1169
1039
|
def _jax_wrapped_if_then_else(x, params, key):
|
|
1170
|
-
sample1, key, err1 = jax_pred(x, params, key)
|
|
1171
|
-
sample2, key, err2 = jax_true(x, params, key)
|
|
1172
|
-
sample3, key, err3 = jax_false(x, params, key)
|
|
1173
|
-
|
|
1174
|
-
sample = jax_if(sample1, sample2, sample3, param)
|
|
1040
|
+
sample1, key, err1, params = jax_pred(x, params, key)
|
|
1041
|
+
sample2, key, err2, params = jax_true(x, params, key)
|
|
1042
|
+
sample3, key, err3, params = jax_false(x, params, key)
|
|
1043
|
+
sample, params = jax_op(sample1, sample2, sample3, params)
|
|
1175
1044
|
err = err1 | err2 | err3
|
|
1176
1045
|
invalid_cast = jnp.logical_not(jnp.can_cast(sample1, bool))
|
|
1177
1046
|
err |= (invalid_cast * ERR)
|
|
1178
|
-
return sample, key, err
|
|
1047
|
+
return sample, key, err, params
|
|
1179
1048
|
|
|
1180
1049
|
return _jax_wrapped_if_then_else
|
|
1181
1050
|
|
|
1182
|
-
def _jax_switch(self, expr,
|
|
1051
|
+
def _jax_switch(self, expr, init_params):
|
|
1183
1052
|
pred, *_ = expr.args
|
|
1184
1053
|
|
|
1185
1054
|
# if predicate is non-fluent, always use the exact operation
|
|
@@ -1188,34 +1057,33 @@ class JaxRDDLCompiler:
|
|
|
1188
1057
|
switch_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
|
|
1189
1058
|
else:
|
|
1190
1059
|
switch_op = self.SWITCH_HELPER
|
|
1191
|
-
|
|
1060
|
+
jax_op = switch_op(expr.id, init_params)
|
|
1192
1061
|
|
|
1193
1062
|
# recursively compile predicate
|
|
1194
|
-
jax_pred = self._jax(pred,
|
|
1063
|
+
jax_pred = self._jax(pred, init_params)
|
|
1195
1064
|
|
|
1196
1065
|
# recursively compile cases
|
|
1197
1066
|
cases, default = self.traced.cached_sim_info(expr)
|
|
1198
|
-
jax_default = None if default is None else self._jax(default,
|
|
1199
|
-
jax_cases = [(jax_default if _case is None else self._jax(_case,
|
|
1067
|
+
jax_default = None if default is None else self._jax(default, init_params)
|
|
1068
|
+
jax_cases = [(jax_default if _case is None else self._jax(_case, init_params))
|
|
1200
1069
|
for _case in cases]
|
|
1201
1070
|
|
|
1202
1071
|
def _jax_wrapped_switch(x, params, key):
|
|
1203
1072
|
|
|
1204
1073
|
# sample predicate
|
|
1205
|
-
sample_pred, key, err = jax_pred(x, params, key)
|
|
1074
|
+
sample_pred, key, err, params = jax_pred(x, params, key)
|
|
1206
1075
|
|
|
1207
1076
|
# sample cases
|
|
1208
1077
|
sample_cases = [None] * len(jax_cases)
|
|
1209
1078
|
for (i, jax_case) in enumerate(jax_cases):
|
|
1210
|
-
sample_cases[i], key, err_case = jax_case(x, params, key)
|
|
1079
|
+
sample_cases[i], key, err_case, params = jax_case(x, params, key)
|
|
1211
1080
|
err |= err_case
|
|
1212
1081
|
sample_cases = jnp.asarray(
|
|
1213
1082
|
sample_cases, dtype=self._fix_dtype(sample_cases))
|
|
1214
1083
|
|
|
1215
1084
|
# predicate (enum) is an integer - use it to extract from case array
|
|
1216
|
-
|
|
1217
|
-
sample
|
|
1218
|
-
return sample, key, err
|
|
1085
|
+
sample, params = jax_op(sample_pred, sample_cases, params)
|
|
1086
|
+
return sample, key, err, params
|
|
1219
1087
|
|
|
1220
1088
|
return _jax_wrapped_switch
|
|
1221
1089
|
|
|
@@ -1251,169 +1119,169 @@ class JaxRDDLCompiler:
|
|
|
1251
1119
|
# Geometric: (implement safe floor)
|
|
1252
1120
|
# Student: (no reparameterization)
|
|
1253
1121
|
|
|
1254
|
-
def _jax_random(self, expr,
|
|
1122
|
+
def _jax_random(self, expr, init_params):
|
|
1255
1123
|
_, name = expr.etype
|
|
1256
1124
|
if name == 'KronDelta':
|
|
1257
|
-
return self._jax_kron(expr,
|
|
1125
|
+
return self._jax_kron(expr, init_params)
|
|
1258
1126
|
elif name == 'DiracDelta':
|
|
1259
|
-
return self._jax_dirac(expr,
|
|
1127
|
+
return self._jax_dirac(expr, init_params)
|
|
1260
1128
|
elif name == 'Uniform':
|
|
1261
|
-
return self._jax_uniform(expr,
|
|
1129
|
+
return self._jax_uniform(expr, init_params)
|
|
1262
1130
|
elif name == 'Bernoulli':
|
|
1263
|
-
return self._jax_bernoulli(expr,
|
|
1131
|
+
return self._jax_bernoulli(expr, init_params)
|
|
1264
1132
|
elif name == 'Normal':
|
|
1265
|
-
return self._jax_normal(expr,
|
|
1133
|
+
return self._jax_normal(expr, init_params)
|
|
1266
1134
|
elif name == 'Poisson':
|
|
1267
|
-
return self._jax_poisson(expr,
|
|
1135
|
+
return self._jax_poisson(expr, init_params)
|
|
1268
1136
|
elif name == 'Exponential':
|
|
1269
|
-
return self._jax_exponential(expr,
|
|
1137
|
+
return self._jax_exponential(expr, init_params)
|
|
1270
1138
|
elif name == 'Weibull':
|
|
1271
|
-
return self._jax_weibull(expr,
|
|
1139
|
+
return self._jax_weibull(expr, init_params)
|
|
1272
1140
|
elif name == 'Gamma':
|
|
1273
|
-
return self._jax_gamma(expr,
|
|
1141
|
+
return self._jax_gamma(expr, init_params)
|
|
1274
1142
|
elif name == 'Binomial':
|
|
1275
|
-
return self._jax_binomial(expr,
|
|
1143
|
+
return self._jax_binomial(expr, init_params)
|
|
1276
1144
|
elif name == 'NegativeBinomial':
|
|
1277
|
-
return self._jax_negative_binomial(expr,
|
|
1145
|
+
return self._jax_negative_binomial(expr, init_params)
|
|
1278
1146
|
elif name == 'Beta':
|
|
1279
|
-
return self._jax_beta(expr,
|
|
1147
|
+
return self._jax_beta(expr, init_params)
|
|
1280
1148
|
elif name == 'Geometric':
|
|
1281
|
-
return self._jax_geometric(expr,
|
|
1149
|
+
return self._jax_geometric(expr, init_params)
|
|
1282
1150
|
elif name == 'Pareto':
|
|
1283
|
-
return self._jax_pareto(expr,
|
|
1151
|
+
return self._jax_pareto(expr, init_params)
|
|
1284
1152
|
elif name == 'Student':
|
|
1285
|
-
return self._jax_student(expr,
|
|
1153
|
+
return self._jax_student(expr, init_params)
|
|
1286
1154
|
elif name == 'Gumbel':
|
|
1287
|
-
return self._jax_gumbel(expr,
|
|
1155
|
+
return self._jax_gumbel(expr, init_params)
|
|
1288
1156
|
elif name == 'Laplace':
|
|
1289
|
-
return self._jax_laplace(expr,
|
|
1157
|
+
return self._jax_laplace(expr, init_params)
|
|
1290
1158
|
elif name == 'Cauchy':
|
|
1291
|
-
return self._jax_cauchy(expr,
|
|
1159
|
+
return self._jax_cauchy(expr, init_params)
|
|
1292
1160
|
elif name == 'Gompertz':
|
|
1293
|
-
return self._jax_gompertz(expr,
|
|
1161
|
+
return self._jax_gompertz(expr, init_params)
|
|
1294
1162
|
elif name == 'ChiSquare':
|
|
1295
|
-
return self._jax_chisquare(expr,
|
|
1163
|
+
return self._jax_chisquare(expr, init_params)
|
|
1296
1164
|
elif name == 'Kumaraswamy':
|
|
1297
|
-
return self._jax_kumaraswamy(expr,
|
|
1165
|
+
return self._jax_kumaraswamy(expr, init_params)
|
|
1298
1166
|
elif name == 'Discrete':
|
|
1299
|
-
return self._jax_discrete(expr,
|
|
1167
|
+
return self._jax_discrete(expr, init_params, unnorm=False)
|
|
1300
1168
|
elif name == 'UnnormDiscrete':
|
|
1301
|
-
return self._jax_discrete(expr,
|
|
1169
|
+
return self._jax_discrete(expr, init_params, unnorm=True)
|
|
1302
1170
|
elif name == 'Discrete(p)':
|
|
1303
|
-
return self._jax_discrete_pvar(expr,
|
|
1171
|
+
return self._jax_discrete_pvar(expr, init_params, unnorm=False)
|
|
1304
1172
|
elif name == 'UnnormDiscrete(p)':
|
|
1305
|
-
return self._jax_discrete_pvar(expr,
|
|
1173
|
+
return self._jax_discrete_pvar(expr, init_params, unnorm=True)
|
|
1306
1174
|
else:
|
|
1307
1175
|
raise RDDLNotImplementedError(
|
|
1308
1176
|
f'Distribution {name} is not supported.\n' +
|
|
1309
1177
|
print_stack_trace(expr))
|
|
1310
1178
|
|
|
1311
|
-
def _jax_kron(self, expr,
|
|
1179
|
+
def _jax_kron(self, expr, init_params):
|
|
1312
1180
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_KRON_DELTA']
|
|
1313
1181
|
JaxRDDLCompiler._check_num_args(expr, 1)
|
|
1314
1182
|
arg, = expr.args
|
|
1315
|
-
arg = self._jax(arg,
|
|
1183
|
+
arg = self._jax(arg, init_params)
|
|
1316
1184
|
|
|
1317
1185
|
# just check that the sample can be cast to int
|
|
1318
1186
|
def _jax_wrapped_distribution_kron(x, params, key):
|
|
1319
|
-
sample, key, err = arg(x, params, key)
|
|
1187
|
+
sample, key, err, params = arg(x, params, key)
|
|
1320
1188
|
invalid_cast = jnp.logical_not(jnp.can_cast(sample, self.INT))
|
|
1321
1189
|
err |= (invalid_cast * ERR)
|
|
1322
|
-
return sample, key, err
|
|
1190
|
+
return sample, key, err, params
|
|
1323
1191
|
|
|
1324
1192
|
return _jax_wrapped_distribution_kron
|
|
1325
1193
|
|
|
1326
|
-
def _jax_dirac(self, expr,
|
|
1194
|
+
def _jax_dirac(self, expr, init_params):
|
|
1327
1195
|
JaxRDDLCompiler._check_num_args(expr, 1)
|
|
1328
1196
|
arg, = expr.args
|
|
1329
|
-
arg = self._jax(arg,
|
|
1197
|
+
arg = self._jax(arg, init_params, dtype=self.REAL)
|
|
1330
1198
|
return arg
|
|
1331
1199
|
|
|
1332
|
-
def _jax_uniform(self, expr,
|
|
1200
|
+
def _jax_uniform(self, expr, init_params):
|
|
1333
1201
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_UNIFORM']
|
|
1334
1202
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1335
1203
|
|
|
1336
1204
|
arg_lb, arg_ub = expr.args
|
|
1337
|
-
jax_lb = self._jax(arg_lb,
|
|
1338
|
-
jax_ub = self._jax(arg_ub,
|
|
1205
|
+
jax_lb = self._jax(arg_lb, init_params)
|
|
1206
|
+
jax_ub = self._jax(arg_ub, init_params)
|
|
1339
1207
|
|
|
1340
1208
|
# reparameterization trick U(a, b) = a + (b - a) * U(0, 1)
|
|
1341
1209
|
def _jax_wrapped_distribution_uniform(x, params, key):
|
|
1342
|
-
lb, key, err1 = jax_lb(x, params, key)
|
|
1343
|
-
ub, key, err2 = jax_ub(x, params, key)
|
|
1210
|
+
lb, key, err1, params = jax_lb(x, params, key)
|
|
1211
|
+
ub, key, err2, params = jax_ub(x, params, key)
|
|
1344
1212
|
key, subkey = random.split(key)
|
|
1345
1213
|
U = random.uniform(key=subkey, shape=jnp.shape(lb), dtype=self.REAL)
|
|
1346
1214
|
sample = lb + (ub - lb) * U
|
|
1347
1215
|
out_of_bounds = jnp.logical_not(jnp.all(lb <= ub))
|
|
1348
1216
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1349
|
-
return sample, key, err
|
|
1217
|
+
return sample, key, err, params
|
|
1350
1218
|
|
|
1351
1219
|
return _jax_wrapped_distribution_uniform
|
|
1352
1220
|
|
|
1353
|
-
def _jax_normal(self, expr,
|
|
1221
|
+
def _jax_normal(self, expr, init_params):
|
|
1354
1222
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_NORMAL']
|
|
1355
1223
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1356
1224
|
|
|
1357
1225
|
arg_mean, arg_var = expr.args
|
|
1358
|
-
jax_mean = self._jax(arg_mean,
|
|
1359
|
-
jax_var = self._jax(arg_var,
|
|
1226
|
+
jax_mean = self._jax(arg_mean, init_params)
|
|
1227
|
+
jax_var = self._jax(arg_var, init_params)
|
|
1360
1228
|
|
|
1361
1229
|
# reparameterization trick N(m, s^2) = m + s * N(0, 1)
|
|
1362
1230
|
def _jax_wrapped_distribution_normal(x, params, key):
|
|
1363
|
-
mean, key, err1 = jax_mean(x, params, key)
|
|
1364
|
-
var, key, err2 = jax_var(x, params, key)
|
|
1231
|
+
mean, key, err1, params = jax_mean(x, params, key)
|
|
1232
|
+
var, key, err2, params = jax_var(x, params, key)
|
|
1365
1233
|
std = jnp.sqrt(var)
|
|
1366
1234
|
key, subkey = random.split(key)
|
|
1367
1235
|
Z = random.normal(key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
|
|
1368
1236
|
sample = mean + std * Z
|
|
1369
1237
|
out_of_bounds = jnp.logical_not(jnp.all(var >= 0))
|
|
1370
1238
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1371
|
-
return sample, key, err
|
|
1239
|
+
return sample, key, err, params
|
|
1372
1240
|
|
|
1373
1241
|
return _jax_wrapped_distribution_normal
|
|
1374
1242
|
|
|
1375
|
-
def _jax_exponential(self, expr,
|
|
1243
|
+
def _jax_exponential(self, expr, init_params):
|
|
1376
1244
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_EXPONENTIAL']
|
|
1377
1245
|
JaxRDDLCompiler._check_num_args(expr, 1)
|
|
1378
1246
|
|
|
1379
1247
|
arg_scale, = expr.args
|
|
1380
|
-
jax_scale = self._jax(arg_scale,
|
|
1248
|
+
jax_scale = self._jax(arg_scale, init_params)
|
|
1381
1249
|
|
|
1382
1250
|
# reparameterization trick Exp(s) = s * Exp(1)
|
|
1383
1251
|
def _jax_wrapped_distribution_exp(x, params, key):
|
|
1384
|
-
scale, key, err = jax_scale(x, params, key)
|
|
1252
|
+
scale, key, err, params = jax_scale(x, params, key)
|
|
1385
1253
|
key, subkey = random.split(key)
|
|
1386
1254
|
Exp1 = random.exponential(
|
|
1387
1255
|
key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
|
|
1388
1256
|
sample = scale * Exp1
|
|
1389
1257
|
out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
|
|
1390
1258
|
err |= (out_of_bounds * ERR)
|
|
1391
|
-
return sample, key, err
|
|
1259
|
+
return sample, key, err, params
|
|
1392
1260
|
|
|
1393
1261
|
return _jax_wrapped_distribution_exp
|
|
1394
1262
|
|
|
1395
|
-
def _jax_weibull(self, expr,
|
|
1263
|
+
def _jax_weibull(self, expr, init_params):
|
|
1396
1264
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_WEIBULL']
|
|
1397
1265
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1398
1266
|
|
|
1399
1267
|
arg_shape, arg_scale = expr.args
|
|
1400
|
-
jax_shape = self._jax(arg_shape,
|
|
1401
|
-
jax_scale = self._jax(arg_scale,
|
|
1268
|
+
jax_shape = self._jax(arg_shape, init_params)
|
|
1269
|
+
jax_scale = self._jax(arg_scale, init_params)
|
|
1402
1270
|
|
|
1403
1271
|
# reparameterization trick W(s, r) = r * (-ln(1 - U(0, 1))) ** (1 / s)
|
|
1404
1272
|
def _jax_wrapped_distribution_weibull(x, params, key):
|
|
1405
|
-
shape, key, err1 = jax_shape(x, params, key)
|
|
1406
|
-
scale, key, err2 = jax_scale(x, params, key)
|
|
1273
|
+
shape, key, err1, params = jax_shape(x, params, key)
|
|
1274
|
+
scale, key, err2, params = jax_scale(x, params, key)
|
|
1407
1275
|
key, subkey = random.split(key)
|
|
1408
1276
|
U = random.uniform(key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
|
|
1409
1277
|
sample = scale * jnp.power(-jnp.log(U), 1.0 / shape)
|
|
1410
1278
|
out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
|
|
1411
1279
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1412
|
-
return sample, key, err
|
|
1280
|
+
return sample, key, err, params
|
|
1413
1281
|
|
|
1414
1282
|
return _jax_wrapped_distribution_weibull
|
|
1415
1283
|
|
|
1416
|
-
def _jax_bernoulli(self, expr,
|
|
1284
|
+
def _jax_bernoulli(self, expr, init_params):
|
|
1417
1285
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_BERNOULLI']
|
|
1418
1286
|
JaxRDDLCompiler._check_num_args(expr, 1)
|
|
1419
1287
|
arg_prob, = expr.args
|
|
@@ -1423,23 +1291,22 @@ class JaxRDDLCompiler:
|
|
|
1423
1291
|
bern_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BERNOULLI
|
|
1424
1292
|
else:
|
|
1425
1293
|
bern_op = self.BERNOULLI_HELPER
|
|
1426
|
-
|
|
1294
|
+
jax_op = bern_op(expr.id, init_params)
|
|
1427
1295
|
|
|
1428
1296
|
# recursively compile arguments
|
|
1429
|
-
jax_prob = self._jax(arg_prob,
|
|
1297
|
+
jax_prob = self._jax(arg_prob, init_params)
|
|
1430
1298
|
|
|
1431
1299
|
def _jax_wrapped_distribution_bernoulli(x, params, key):
|
|
1432
|
-
prob, key, err = jax_prob(x, params, key)
|
|
1300
|
+
prob, key, err, params = jax_prob(x, params, key)
|
|
1433
1301
|
key, subkey = random.split(key)
|
|
1434
|
-
|
|
1435
|
-
sample = jax_bern(subkey, prob, param)
|
|
1302
|
+
sample, params = jax_op(subkey, prob, params)
|
|
1436
1303
|
out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
|
|
1437
1304
|
err |= (out_of_bounds * ERR)
|
|
1438
|
-
return sample, key, err
|
|
1305
|
+
return sample, key, err, params
|
|
1439
1306
|
|
|
1440
1307
|
return _jax_wrapped_distribution_bernoulli
|
|
1441
1308
|
|
|
1442
|
-
def _jax_poisson(self, expr,
|
|
1309
|
+
def _jax_poisson(self, expr, init_params):
|
|
1443
1310
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_POISSON']
|
|
1444
1311
|
JaxRDDLCompiler._check_num_args(expr, 1)
|
|
1445
1312
|
arg_rate, = expr.args
|
|
@@ -1449,57 +1316,57 @@ class JaxRDDLCompiler:
|
|
|
1449
1316
|
poisson_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_POISSON
|
|
1450
1317
|
else:
|
|
1451
1318
|
poisson_op = self.POISSON_HELPER
|
|
1452
|
-
|
|
1319
|
+
jax_op = poisson_op(expr.id, init_params)
|
|
1453
1320
|
|
|
1454
1321
|
# recursively compile arguments
|
|
1455
|
-
jax_rate = self._jax(arg_rate,
|
|
1322
|
+
jax_rate = self._jax(arg_rate, init_params)
|
|
1456
1323
|
|
|
1457
1324
|
# uses the implicit JAX subroutine
|
|
1458
1325
|
def _jax_wrapped_distribution_poisson(x, params, key):
|
|
1459
|
-
rate, key, err = jax_rate(x, params, key)
|
|
1326
|
+
rate, key, err, params = jax_rate(x, params, key)
|
|
1460
1327
|
key, subkey = random.split(key)
|
|
1461
|
-
|
|
1462
|
-
sample =
|
|
1328
|
+
sample, params = jax_op(subkey, rate, params)
|
|
1329
|
+
sample = sample.astype(self.INT)
|
|
1463
1330
|
out_of_bounds = jnp.logical_not(jnp.all(rate >= 0))
|
|
1464
1331
|
err |= (out_of_bounds * ERR)
|
|
1465
|
-
return sample, key, err
|
|
1332
|
+
return sample, key, err, params
|
|
1466
1333
|
|
|
1467
1334
|
return _jax_wrapped_distribution_poisson
|
|
1468
1335
|
|
|
1469
|
-
def _jax_gamma(self, expr,
|
|
1336
|
+
def _jax_gamma(self, expr, init_params):
|
|
1470
1337
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_GAMMA']
|
|
1471
1338
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1472
1339
|
|
|
1473
1340
|
arg_shape, arg_scale = expr.args
|
|
1474
|
-
jax_shape = self._jax(arg_shape,
|
|
1475
|
-
jax_scale = self._jax(arg_scale,
|
|
1341
|
+
jax_shape = self._jax(arg_shape, init_params)
|
|
1342
|
+
jax_scale = self._jax(arg_scale, init_params)
|
|
1476
1343
|
|
|
1477
1344
|
# partial reparameterization trick Gamma(s, r) = r * Gamma(s, 1)
|
|
1478
1345
|
# uses the implicit JAX subroutine for Gamma(s, 1)
|
|
1479
1346
|
def _jax_wrapped_distribution_gamma(x, params, key):
|
|
1480
|
-
shape, key, err1 = jax_shape(x, params, key)
|
|
1481
|
-
scale, key, err2 = jax_scale(x, params, key)
|
|
1347
|
+
shape, key, err1, params = jax_shape(x, params, key)
|
|
1348
|
+
scale, key, err2, params = jax_scale(x, params, key)
|
|
1482
1349
|
key, subkey = random.split(key)
|
|
1483
1350
|
Gamma = random.gamma(key=subkey, a=shape, dtype=self.REAL)
|
|
1484
1351
|
sample = scale * Gamma
|
|
1485
1352
|
out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
|
|
1486
1353
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1487
|
-
return sample, key, err
|
|
1354
|
+
return sample, key, err, params
|
|
1488
1355
|
|
|
1489
1356
|
return _jax_wrapped_distribution_gamma
|
|
1490
1357
|
|
|
1491
|
-
def _jax_binomial(self, expr,
|
|
1358
|
+
def _jax_binomial(self, expr, init_params):
|
|
1492
1359
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_BINOMIAL']
|
|
1493
1360
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1494
1361
|
|
|
1495
1362
|
arg_trials, arg_prob = expr.args
|
|
1496
|
-
jax_trials = self._jax(arg_trials,
|
|
1497
|
-
jax_prob = self._jax(arg_prob,
|
|
1363
|
+
jax_trials = self._jax(arg_trials, init_params)
|
|
1364
|
+
jax_prob = self._jax(arg_prob, init_params)
|
|
1498
1365
|
|
|
1499
1366
|
# uses the JAX substrate of tensorflow-probability
|
|
1500
1367
|
def _jax_wrapped_distribution_binomial(x, params, key):
|
|
1501
|
-
trials, key, err2 = jax_trials(x, params, key)
|
|
1502
|
-
prob, key, err1 = jax_prob(x, params, key)
|
|
1368
|
+
trials, key, err2, params = jax_trials(x, params, key)
|
|
1369
|
+
prob, key, err1, params = jax_prob(x, params, key)
|
|
1503
1370
|
trials = jnp.asarray(trials, dtype=self.REAL)
|
|
1504
1371
|
prob = jnp.asarray(prob, dtype=self.REAL)
|
|
1505
1372
|
key, subkey = random.split(key)
|
|
@@ -1508,55 +1375,56 @@ class JaxRDDLCompiler:
|
|
|
1508
1375
|
out_of_bounds = jnp.logical_not(jnp.all(
|
|
1509
1376
|
(prob >= 0) & (prob <= 1) & (trials >= 0)))
|
|
1510
1377
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1511
|
-
return sample, key, err
|
|
1378
|
+
return sample, key, err, params
|
|
1512
1379
|
|
|
1513
1380
|
return _jax_wrapped_distribution_binomial
|
|
1514
1381
|
|
|
1515
|
-
def _jax_negative_binomial(self, expr,
|
|
1382
|
+
def _jax_negative_binomial(self, expr, init_params):
|
|
1516
1383
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_NEGATIVE_BINOMIAL']
|
|
1517
1384
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1518
1385
|
|
|
1519
1386
|
arg_trials, arg_prob = expr.args
|
|
1520
|
-
jax_trials = self._jax(arg_trials,
|
|
1521
|
-
jax_prob = self._jax(arg_prob,
|
|
1387
|
+
jax_trials = self._jax(arg_trials, init_params)
|
|
1388
|
+
jax_prob = self._jax(arg_prob, init_params)
|
|
1522
1389
|
|
|
1523
1390
|
# uses the JAX substrate of tensorflow-probability
|
|
1524
1391
|
def _jax_wrapped_distribution_negative_binomial(x, params, key):
|
|
1525
|
-
trials, key, err2 = jax_trials(x, params, key)
|
|
1526
|
-
prob, key, err1 = jax_prob(x, params, key)
|
|
1392
|
+
trials, key, err2, params = jax_trials(x, params, key)
|
|
1393
|
+
prob, key, err1, params = jax_prob(x, params, key)
|
|
1527
1394
|
trials = jnp.asarray(trials, dtype=self.REAL)
|
|
1528
1395
|
prob = jnp.asarray(prob, dtype=self.REAL)
|
|
1529
1396
|
key, subkey = random.split(key)
|
|
1530
1397
|
dist = tfp.distributions.NegativeBinomial(total_count=trials, probs=prob)
|
|
1531
1398
|
sample = dist.sample(seed=subkey).astype(self.INT)
|
|
1532
1399
|
out_of_bounds = jnp.logical_not(jnp.all(
|
|
1533
|
-
(prob >= 0) & (prob <= 1) & (trials > 0))
|
|
1400
|
+
(prob >= 0) & (prob <= 1) & (trials > 0))
|
|
1401
|
+
)
|
|
1534
1402
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1535
|
-
return sample, key, err
|
|
1403
|
+
return sample, key, err, params
|
|
1536
1404
|
|
|
1537
1405
|
return _jax_wrapped_distribution_negative_binomial
|
|
1538
1406
|
|
|
1539
|
-
def _jax_beta(self, expr,
|
|
1407
|
+
def _jax_beta(self, expr, init_params):
|
|
1540
1408
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_BETA']
|
|
1541
1409
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1542
1410
|
|
|
1543
1411
|
arg_shape, arg_rate = expr.args
|
|
1544
|
-
jax_shape = self._jax(arg_shape,
|
|
1545
|
-
jax_rate = self._jax(arg_rate,
|
|
1412
|
+
jax_shape = self._jax(arg_shape, init_params)
|
|
1413
|
+
jax_rate = self._jax(arg_rate, init_params)
|
|
1546
1414
|
|
|
1547
1415
|
# uses the implicit JAX subroutine
|
|
1548
1416
|
def _jax_wrapped_distribution_beta(x, params, key):
|
|
1549
|
-
shape, key, err1 = jax_shape(x, params, key)
|
|
1550
|
-
rate, key, err2 = jax_rate(x, params, key)
|
|
1417
|
+
shape, key, err1, params = jax_shape(x, params, key)
|
|
1418
|
+
rate, key, err2, params = jax_rate(x, params, key)
|
|
1551
1419
|
key, subkey = random.split(key)
|
|
1552
1420
|
sample = random.beta(key=subkey, a=shape, b=rate, dtype=self.REAL)
|
|
1553
1421
|
out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (rate > 0)))
|
|
1554
1422
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1555
|
-
return sample, key, err
|
|
1423
|
+
return sample, key, err, params
|
|
1556
1424
|
|
|
1557
1425
|
return _jax_wrapped_distribution_beta
|
|
1558
1426
|
|
|
1559
|
-
def _jax_geometric(self, expr,
|
|
1427
|
+
def _jax_geometric(self, expr, init_params):
|
|
1560
1428
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_GEOMETRIC']
|
|
1561
1429
|
JaxRDDLCompiler._check_num_args(expr, 1)
|
|
1562
1430
|
arg_prob, = expr.args
|
|
@@ -1566,187 +1434,187 @@ class JaxRDDLCompiler:
|
|
|
1566
1434
|
geom_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_GEOMETRIC
|
|
1567
1435
|
else:
|
|
1568
1436
|
geom_op = self.GEOMETRIC_HELPER
|
|
1569
|
-
|
|
1437
|
+
jax_op = geom_op(expr.id, init_params)
|
|
1570
1438
|
|
|
1571
1439
|
# recursively compile arguments
|
|
1572
|
-
jax_prob = self._jax(arg_prob,
|
|
1440
|
+
jax_prob = self._jax(arg_prob, init_params)
|
|
1573
1441
|
|
|
1574
1442
|
def _jax_wrapped_distribution_geometric(x, params, key):
|
|
1575
|
-
prob, key, err = jax_prob(x, params, key)
|
|
1443
|
+
prob, key, err, params = jax_prob(x, params, key)
|
|
1576
1444
|
key, subkey = random.split(key)
|
|
1577
|
-
|
|
1578
|
-
sample =
|
|
1445
|
+
sample, params = jax_op(subkey, prob, params)
|
|
1446
|
+
sample = sample.astype(self.INT)
|
|
1579
1447
|
out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
|
|
1580
1448
|
err |= (out_of_bounds * ERR)
|
|
1581
|
-
return sample, key, err
|
|
1449
|
+
return sample, key, err, params
|
|
1582
1450
|
|
|
1583
1451
|
return _jax_wrapped_distribution_geometric
|
|
1584
1452
|
|
|
1585
|
-
def _jax_pareto(self, expr,
|
|
1453
|
+
def _jax_pareto(self, expr, init_params):
|
|
1586
1454
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_PARETO']
|
|
1587
1455
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1588
1456
|
|
|
1589
1457
|
arg_shape, arg_scale = expr.args
|
|
1590
|
-
jax_shape = self._jax(arg_shape,
|
|
1591
|
-
jax_scale = self._jax(arg_scale,
|
|
1458
|
+
jax_shape = self._jax(arg_shape, init_params)
|
|
1459
|
+
jax_scale = self._jax(arg_scale, init_params)
|
|
1592
1460
|
|
|
1593
1461
|
# partial reparameterization trick Pareto(s, r) = r * Pareto(s, 1)
|
|
1594
1462
|
# uses the implicit JAX subroutine for Pareto(s, 1)
|
|
1595
1463
|
def _jax_wrapped_distribution_pareto(x, params, key):
|
|
1596
|
-
shape, key, err1 = jax_shape(x, params, key)
|
|
1597
|
-
scale, key, err2 = jax_scale(x, params, key)
|
|
1464
|
+
shape, key, err1, params = jax_shape(x, params, key)
|
|
1465
|
+
scale, key, err2, params = jax_scale(x, params, key)
|
|
1598
1466
|
key, subkey = random.split(key)
|
|
1599
1467
|
sample = scale * random.pareto(key=subkey, b=shape, dtype=self.REAL)
|
|
1600
1468
|
out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
|
|
1601
1469
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1602
|
-
return sample, key, err
|
|
1470
|
+
return sample, key, err, params
|
|
1603
1471
|
|
|
1604
1472
|
return _jax_wrapped_distribution_pareto
|
|
1605
1473
|
|
|
1606
|
-
def _jax_student(self, expr,
|
|
1474
|
+
def _jax_student(self, expr, init_params):
|
|
1607
1475
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_STUDENT']
|
|
1608
1476
|
JaxRDDLCompiler._check_num_args(expr, 1)
|
|
1609
1477
|
|
|
1610
1478
|
arg_df, = expr.args
|
|
1611
|
-
jax_df = self._jax(arg_df,
|
|
1479
|
+
jax_df = self._jax(arg_df, init_params)
|
|
1612
1480
|
|
|
1613
1481
|
# uses the implicit JAX subroutine for student(df)
|
|
1614
1482
|
def _jax_wrapped_distribution_t(x, params, key):
|
|
1615
|
-
df, key, err = jax_df(x, params, key)
|
|
1483
|
+
df, key, err, params = jax_df(x, params, key)
|
|
1616
1484
|
key, subkey = random.split(key)
|
|
1617
1485
|
sample = random.t(
|
|
1618
1486
|
key=subkey, df=df, shape=jnp.shape(df), dtype=self.REAL)
|
|
1619
1487
|
out_of_bounds = jnp.logical_not(jnp.all(df > 0))
|
|
1620
1488
|
err |= (out_of_bounds * ERR)
|
|
1621
|
-
return sample, key, err
|
|
1489
|
+
return sample, key, err, params
|
|
1622
1490
|
|
|
1623
1491
|
return _jax_wrapped_distribution_t
|
|
1624
1492
|
|
|
1625
|
-
def _jax_gumbel(self, expr,
|
|
1493
|
+
def _jax_gumbel(self, expr, init_params):
|
|
1626
1494
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_GUMBEL']
|
|
1627
1495
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1628
1496
|
|
|
1629
1497
|
arg_mean, arg_scale = expr.args
|
|
1630
|
-
jax_mean = self._jax(arg_mean,
|
|
1631
|
-
jax_scale = self._jax(arg_scale,
|
|
1498
|
+
jax_mean = self._jax(arg_mean, init_params)
|
|
1499
|
+
jax_scale = self._jax(arg_scale, init_params)
|
|
1632
1500
|
|
|
1633
1501
|
# reparameterization trick Gumbel(m, s) = m + s * Gumbel(0, 1)
|
|
1634
1502
|
def _jax_wrapped_distribution_gumbel(x, params, key):
|
|
1635
|
-
mean, key, err1 = jax_mean(x, params, key)
|
|
1636
|
-
scale, key, err2 = jax_scale(x, params, key)
|
|
1503
|
+
mean, key, err1, params = jax_mean(x, params, key)
|
|
1504
|
+
scale, key, err2, params = jax_scale(x, params, key)
|
|
1637
1505
|
key, subkey = random.split(key)
|
|
1638
1506
|
Gumbel01 = random.gumbel(
|
|
1639
1507
|
key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
|
|
1640
1508
|
sample = mean + scale * Gumbel01
|
|
1641
1509
|
out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
|
|
1642
1510
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1643
|
-
return sample, key, err
|
|
1511
|
+
return sample, key, err, params
|
|
1644
1512
|
|
|
1645
1513
|
return _jax_wrapped_distribution_gumbel
|
|
1646
1514
|
|
|
1647
|
-
def _jax_laplace(self, expr,
|
|
1515
|
+
def _jax_laplace(self, expr, init_params):
|
|
1648
1516
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_LAPLACE']
|
|
1649
1517
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1650
1518
|
|
|
1651
1519
|
arg_mean, arg_scale = expr.args
|
|
1652
|
-
jax_mean = self._jax(arg_mean,
|
|
1653
|
-
jax_scale = self._jax(arg_scale,
|
|
1520
|
+
jax_mean = self._jax(arg_mean, init_params)
|
|
1521
|
+
jax_scale = self._jax(arg_scale, init_params)
|
|
1654
1522
|
|
|
1655
1523
|
# reparameterization trick Laplace(m, s) = m + s * Laplace(0, 1)
|
|
1656
1524
|
def _jax_wrapped_distribution_laplace(x, params, key):
|
|
1657
|
-
mean, key, err1 = jax_mean(x, params, key)
|
|
1658
|
-
scale, key, err2 = jax_scale(x, params, key)
|
|
1525
|
+
mean, key, err1, params = jax_mean(x, params, key)
|
|
1526
|
+
scale, key, err2, params = jax_scale(x, params, key)
|
|
1659
1527
|
key, subkey = random.split(key)
|
|
1660
1528
|
Laplace01 = random.laplace(
|
|
1661
1529
|
key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
|
|
1662
1530
|
sample = mean + scale * Laplace01
|
|
1663
1531
|
out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
|
|
1664
1532
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1665
|
-
return sample, key, err
|
|
1533
|
+
return sample, key, err, params
|
|
1666
1534
|
|
|
1667
1535
|
return _jax_wrapped_distribution_laplace
|
|
1668
1536
|
|
|
1669
|
-
def _jax_cauchy(self, expr,
|
|
1537
|
+
def _jax_cauchy(self, expr, init_params):
|
|
1670
1538
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_CAUCHY']
|
|
1671
1539
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1672
1540
|
|
|
1673
1541
|
arg_mean, arg_scale = expr.args
|
|
1674
|
-
jax_mean = self._jax(arg_mean,
|
|
1675
|
-
jax_scale = self._jax(arg_scale,
|
|
1542
|
+
jax_mean = self._jax(arg_mean, init_params)
|
|
1543
|
+
jax_scale = self._jax(arg_scale, init_params)
|
|
1676
1544
|
|
|
1677
1545
|
# reparameterization trick Cauchy(m, s) = m + s * Cauchy(0, 1)
|
|
1678
1546
|
def _jax_wrapped_distribution_cauchy(x, params, key):
|
|
1679
|
-
mean, key, err1 = jax_mean(x, params, key)
|
|
1680
|
-
scale, key, err2 = jax_scale(x, params, key)
|
|
1547
|
+
mean, key, err1, params = jax_mean(x, params, key)
|
|
1548
|
+
scale, key, err2, params = jax_scale(x, params, key)
|
|
1681
1549
|
key, subkey = random.split(key)
|
|
1682
1550
|
Cauchy01 = random.cauchy(
|
|
1683
1551
|
key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
|
|
1684
1552
|
sample = mean + scale * Cauchy01
|
|
1685
1553
|
out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
|
|
1686
1554
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1687
|
-
return sample, key, err
|
|
1555
|
+
return sample, key, err, params
|
|
1688
1556
|
|
|
1689
1557
|
return _jax_wrapped_distribution_cauchy
|
|
1690
1558
|
|
|
1691
|
-
def _jax_gompertz(self, expr,
|
|
1559
|
+
def _jax_gompertz(self, expr, init_params):
|
|
1692
1560
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_GOMPERTZ']
|
|
1693
1561
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1694
1562
|
|
|
1695
1563
|
arg_shape, arg_scale = expr.args
|
|
1696
|
-
jax_shape = self._jax(arg_shape,
|
|
1697
|
-
jax_scale = self._jax(arg_scale,
|
|
1564
|
+
jax_shape = self._jax(arg_shape, init_params)
|
|
1565
|
+
jax_scale = self._jax(arg_scale, init_params)
|
|
1698
1566
|
|
|
1699
1567
|
# reparameterization trick Gompertz(s, r) = ln(1 - log(U(0, 1)) / s) / r
|
|
1700
1568
|
def _jax_wrapped_distribution_gompertz(x, params, key):
|
|
1701
|
-
shape, key, err1 = jax_shape(x, params, key)
|
|
1702
|
-
scale, key, err2 = jax_scale(x, params, key)
|
|
1569
|
+
shape, key, err1, params = jax_shape(x, params, key)
|
|
1570
|
+
scale, key, err2, params = jax_scale(x, params, key)
|
|
1703
1571
|
key, subkey = random.split(key)
|
|
1704
1572
|
U = random.uniform(key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
|
|
1705
1573
|
sample = jnp.log(1.0 - jnp.log(U) / shape) / scale
|
|
1706
1574
|
out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
|
|
1707
1575
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1708
|
-
return sample, key, err
|
|
1576
|
+
return sample, key, err, params
|
|
1709
1577
|
|
|
1710
1578
|
return _jax_wrapped_distribution_gompertz
|
|
1711
1579
|
|
|
1712
|
-
def _jax_chisquare(self, expr,
|
|
1580
|
+
def _jax_chisquare(self, expr, init_params):
|
|
1713
1581
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_CHISQUARE']
|
|
1714
1582
|
JaxRDDLCompiler._check_num_args(expr, 1)
|
|
1715
1583
|
|
|
1716
1584
|
arg_df, = expr.args
|
|
1717
|
-
jax_df = self._jax(arg_df,
|
|
1585
|
+
jax_df = self._jax(arg_df, init_params)
|
|
1718
1586
|
|
|
1719
1587
|
# use the fact that ChiSquare(df) = Gamma(df/2, 2)
|
|
1720
1588
|
def _jax_wrapped_distribution_chisquare(x, params, key):
|
|
1721
|
-
df, key, err1 = jax_df(x, params, key)
|
|
1589
|
+
df, key, err1, params = jax_df(x, params, key)
|
|
1722
1590
|
key, subkey = random.split(key)
|
|
1723
1591
|
shape = df / 2.0
|
|
1724
1592
|
Gamma = random.gamma(key=subkey, a=shape, dtype=self.REAL)
|
|
1725
1593
|
sample = 2.0 * Gamma
|
|
1726
1594
|
out_of_bounds = jnp.logical_not(jnp.all(df > 0))
|
|
1727
1595
|
err = err1 | (out_of_bounds * ERR)
|
|
1728
|
-
return sample, key, err
|
|
1596
|
+
return sample, key, err, params
|
|
1729
1597
|
|
|
1730
1598
|
return _jax_wrapped_distribution_chisquare
|
|
1731
1599
|
|
|
1732
|
-
def _jax_kumaraswamy(self, expr,
|
|
1600
|
+
def _jax_kumaraswamy(self, expr, init_params):
|
|
1733
1601
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_KUMARASWAMY']
|
|
1734
1602
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1735
1603
|
|
|
1736
1604
|
arg_a, arg_b = expr.args
|
|
1737
|
-
jax_a = self._jax(arg_a,
|
|
1738
|
-
jax_b = self._jax(arg_b,
|
|
1605
|
+
jax_a = self._jax(arg_a, init_params)
|
|
1606
|
+
jax_b = self._jax(arg_b, init_params)
|
|
1739
1607
|
|
|
1740
1608
|
# uses the reparameterization K(a, b) = (1 - (1 - U(0, 1))^{1/b})^{1/a}
|
|
1741
1609
|
def _jax_wrapped_distribution_kumaraswamy(x, params, key):
|
|
1742
|
-
a, key, err1 = jax_a(x, params, key)
|
|
1743
|
-
b, key, err2 = jax_b(x, params, key)
|
|
1610
|
+
a, key, err1, params = jax_a(x, params, key)
|
|
1611
|
+
b, key, err2, params = jax_b(x, params, key)
|
|
1744
1612
|
key, subkey = random.split(key)
|
|
1745
1613
|
U = random.uniform(key=subkey, shape=jnp.shape(a), dtype=self.REAL)
|
|
1746
1614
|
sample = jnp.power(1.0 - jnp.power(U, 1.0 / b), 1.0 / a)
|
|
1747
1615
|
out_of_bounds = jnp.logical_not(jnp.all((a > 0) & (b > 0)))
|
|
1748
1616
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1749
|
-
return sample, key, err
|
|
1617
|
+
return sample, key, err, params
|
|
1750
1618
|
|
|
1751
1619
|
return _jax_wrapped_distribution_kumaraswamy
|
|
1752
1620
|
|
|
@@ -1754,7 +1622,7 @@ class JaxRDDLCompiler:
|
|
|
1754
1622
|
# random variables with enum support
|
|
1755
1623
|
# ===========================================================================
|
|
1756
1624
|
|
|
1757
|
-
def _jax_discrete(self, expr,
|
|
1625
|
+
def _jax_discrete(self, expr, init_params, unnorm):
|
|
1758
1626
|
NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
|
|
1759
1627
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_DISCRETE']
|
|
1760
1628
|
ordered_args = self.traced.cached_sim_info(expr)
|
|
@@ -1766,10 +1634,10 @@ class JaxRDDLCompiler:
|
|
|
1766
1634
|
discrete_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
|
|
1767
1635
|
else:
|
|
1768
1636
|
discrete_op = self.DISCRETE_HELPER
|
|
1769
|
-
|
|
1637
|
+
jax_op = discrete_op(expr.id, init_params)
|
|
1770
1638
|
|
|
1771
1639
|
# compile probability expressions
|
|
1772
|
-
jax_probs = [self._jax(arg,
|
|
1640
|
+
jax_probs = [self._jax(arg, init_params) for arg in ordered_args]
|
|
1773
1641
|
|
|
1774
1642
|
def _jax_wrapped_distribution_discrete(x, params, key):
|
|
1775
1643
|
|
|
@@ -1777,7 +1645,7 @@ class JaxRDDLCompiler:
|
|
|
1777
1645
|
error = NORMAL
|
|
1778
1646
|
prob = [None] * len(jax_probs)
|
|
1779
1647
|
for (i, jax_prob) in enumerate(jax_probs):
|
|
1780
|
-
prob[i], key, error_pdf = jax_prob(x, params, key)
|
|
1648
|
+
prob[i], key, error_pdf, params = jax_prob(x, params, key)
|
|
1781
1649
|
error |= error_pdf
|
|
1782
1650
|
prob = jnp.stack(prob, axis=-1)
|
|
1783
1651
|
if unnorm:
|
|
@@ -1786,17 +1654,17 @@ class JaxRDDLCompiler:
|
|
|
1786
1654
|
|
|
1787
1655
|
# dispatch to sampling subroutine
|
|
1788
1656
|
key, subkey = random.split(key)
|
|
1789
|
-
|
|
1790
|
-
sample = jax_discrete(subkey, prob, param)
|
|
1657
|
+
sample, params = jax_op(subkey, prob, params)
|
|
1791
1658
|
out_of_bounds = jnp.logical_not(jnp.logical_and(
|
|
1792
1659
|
jnp.all(prob >= 0),
|
|
1793
|
-
jnp.allclose(jnp.sum(prob, axis=-1), 1.0)
|
|
1660
|
+
jnp.allclose(jnp.sum(prob, axis=-1), 1.0)
|
|
1661
|
+
))
|
|
1794
1662
|
error |= (out_of_bounds * ERR)
|
|
1795
|
-
return sample, key, error
|
|
1663
|
+
return sample, key, error, params
|
|
1796
1664
|
|
|
1797
1665
|
return _jax_wrapped_distribution_discrete
|
|
1798
1666
|
|
|
1799
|
-
def _jax_discrete_pvar(self, expr,
|
|
1667
|
+
def _jax_discrete_pvar(self, expr, init_params, unnorm):
|
|
1800
1668
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_DISCRETE']
|
|
1801
1669
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1802
1670
|
_, args = expr.args
|
|
@@ -1807,28 +1675,28 @@ class JaxRDDLCompiler:
|
|
|
1807
1675
|
discrete_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
|
|
1808
1676
|
else:
|
|
1809
1677
|
discrete_op = self.DISCRETE_HELPER
|
|
1810
|
-
|
|
1678
|
+
jax_op = discrete_op(expr.id, init_params)
|
|
1811
1679
|
|
|
1812
1680
|
# compile probability function
|
|
1813
|
-
jax_probs = self._jax(arg,
|
|
1681
|
+
jax_probs = self._jax(arg, init_params)
|
|
1814
1682
|
|
|
1815
1683
|
def _jax_wrapped_distribution_discrete_pvar(x, params, key):
|
|
1816
1684
|
|
|
1817
1685
|
# sample probabilities
|
|
1818
|
-
prob, key, error = jax_probs(x, params, key)
|
|
1686
|
+
prob, key, error, params = jax_probs(x, params, key)
|
|
1819
1687
|
if unnorm:
|
|
1820
1688
|
normalizer = jnp.sum(prob, axis=-1, keepdims=True)
|
|
1821
1689
|
prob = prob / normalizer
|
|
1822
1690
|
|
|
1823
1691
|
# dispatch to sampling subroutine
|
|
1824
1692
|
key, subkey = random.split(key)
|
|
1825
|
-
|
|
1826
|
-
sample = jax_discrete(subkey, prob, param)
|
|
1693
|
+
sample, params = jax_op(subkey, prob, params)
|
|
1827
1694
|
out_of_bounds = jnp.logical_not(jnp.logical_and(
|
|
1828
1695
|
jnp.all(prob >= 0),
|
|
1829
|
-
jnp.allclose(jnp.sum(prob, axis=-1), 1.0)
|
|
1696
|
+
jnp.allclose(jnp.sum(prob, axis=-1), 1.0)
|
|
1697
|
+
))
|
|
1830
1698
|
error |= (out_of_bounds * ERR)
|
|
1831
|
-
return sample, key, error
|
|
1699
|
+
return sample, key, error, params
|
|
1832
1700
|
|
|
1833
1701
|
return _jax_wrapped_distribution_discrete_pvar
|
|
1834
1702
|
|
|
@@ -1836,68 +1704,69 @@ class JaxRDDLCompiler:
|
|
|
1836
1704
|
# random vectors
|
|
1837
1705
|
# ===========================================================================
|
|
1838
1706
|
|
|
1839
|
-
def _jax_random_vector(self, expr,
|
|
1707
|
+
def _jax_random_vector(self, expr, init_params):
|
|
1840
1708
|
_, name = expr.etype
|
|
1841
1709
|
if name == 'MultivariateNormal':
|
|
1842
|
-
return self._jax_multivariate_normal(expr,
|
|
1710
|
+
return self._jax_multivariate_normal(expr, init_params)
|
|
1843
1711
|
elif name == 'MultivariateStudent':
|
|
1844
|
-
return self._jax_multivariate_student(expr,
|
|
1712
|
+
return self._jax_multivariate_student(expr, init_params)
|
|
1845
1713
|
elif name == 'Dirichlet':
|
|
1846
|
-
return self._jax_dirichlet(expr,
|
|
1714
|
+
return self._jax_dirichlet(expr, init_params)
|
|
1847
1715
|
elif name == 'Multinomial':
|
|
1848
|
-
return self._jax_multinomial(expr,
|
|
1716
|
+
return self._jax_multinomial(expr, init_params)
|
|
1849
1717
|
else:
|
|
1850
1718
|
raise RDDLNotImplementedError(
|
|
1851
1719
|
f'Distribution {name} is not supported.\n' +
|
|
1852
1720
|
print_stack_trace(expr))
|
|
1853
1721
|
|
|
1854
|
-
def _jax_multivariate_normal(self, expr,
|
|
1722
|
+
def _jax_multivariate_normal(self, expr, init_params):
|
|
1855
1723
|
_, args = expr.args
|
|
1856
1724
|
mean, cov = args
|
|
1857
|
-
jax_mean = self._jax(mean,
|
|
1858
|
-
jax_cov = self._jax(cov,
|
|
1725
|
+
jax_mean = self._jax(mean, init_params)
|
|
1726
|
+
jax_cov = self._jax(cov, init_params)
|
|
1859
1727
|
index, = self.traced.cached_sim_info(expr)
|
|
1860
1728
|
|
|
1861
1729
|
# reparameterization trick MN(m, LL') = LZ + m, where Z ~ Normal(0, 1)
|
|
1862
1730
|
def _jax_wrapped_distribution_multivariate_normal(x, params, key):
|
|
1863
1731
|
|
|
1864
1732
|
# sample the mean and covariance
|
|
1865
|
-
sample_mean, key, err1 = jax_mean(x, params, key)
|
|
1866
|
-
sample_cov, key, err2 = jax_cov(x, params, key)
|
|
1733
|
+
sample_mean, key, err1, params = jax_mean(x, params, key)
|
|
1734
|
+
sample_cov, key, err2, params = jax_cov(x, params, key)
|
|
1867
1735
|
|
|
1868
1736
|
# sample Normal(0, 1)
|
|
1869
1737
|
key, subkey = random.split(key)
|
|
1870
1738
|
Z = random.normal(
|
|
1871
1739
|
key=subkey,
|
|
1872
1740
|
shape=jnp.shape(sample_mean) + (1,),
|
|
1873
|
-
dtype=self.REAL
|
|
1741
|
+
dtype=self.REAL
|
|
1742
|
+
)
|
|
1874
1743
|
|
|
1875
1744
|
# compute L s.t. cov = L * L' and reparameterize
|
|
1876
1745
|
L = jnp.linalg.cholesky(sample_cov)
|
|
1877
1746
|
sample = jnp.matmul(L, Z)[..., 0] + sample_mean
|
|
1878
1747
|
sample = jnp.moveaxis(sample, source=-1, destination=index)
|
|
1879
1748
|
err = err1 | err2
|
|
1880
|
-
return sample, key, err
|
|
1749
|
+
return sample, key, err, params
|
|
1881
1750
|
|
|
1882
1751
|
return _jax_wrapped_distribution_multivariate_normal
|
|
1883
1752
|
|
|
1884
|
-
def _jax_multivariate_student(self, expr,
|
|
1753
|
+
def _jax_multivariate_student(self, expr, init_params):
|
|
1885
1754
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_MULTIVARIATE_STUDENT']
|
|
1886
1755
|
|
|
1887
1756
|
_, args = expr.args
|
|
1888
1757
|
mean, cov, df = args
|
|
1889
|
-
jax_mean = self._jax(mean,
|
|
1890
|
-
jax_cov = self._jax(cov,
|
|
1891
|
-
jax_df = self._jax(df,
|
|
1758
|
+
jax_mean = self._jax(mean, init_params)
|
|
1759
|
+
jax_cov = self._jax(cov, init_params)
|
|
1760
|
+
jax_df = self._jax(df, init_params)
|
|
1892
1761
|
index, = self.traced.cached_sim_info(expr)
|
|
1893
1762
|
|
|
1894
1763
|
# reparameterization trick MN(m, LL') = LZ + m, where Z ~ StudentT(0, 1)
|
|
1895
1764
|
def _jax_wrapped_distribution_multivariate_student(x, params, key):
|
|
1896
1765
|
|
|
1897
1766
|
# sample the mean and covariance and degrees of freedom
|
|
1898
|
-
sample_mean, key, err1 = jax_mean(x, params, key)
|
|
1899
|
-
sample_cov, key, err2 = jax_cov(x, params, key)
|
|
1900
|
-
sample_df, key, err3 = jax_df(x, params, key)
|
|
1767
|
+
sample_mean, key, err1, params = jax_mean(x, params, key)
|
|
1768
|
+
sample_cov, key, err2, params = jax_cov(x, params, key)
|
|
1769
|
+
sample_df, key, err3, params = jax_df(x, params, key)
|
|
1901
1770
|
out_of_bounds = jnp.logical_not(jnp.all(sample_df > 0))
|
|
1902
1771
|
|
|
1903
1772
|
# sample StudentT(0, 1, df) -- broadcast df to same shape as cov
|
|
@@ -1908,50 +1777,51 @@ class JaxRDDLCompiler:
|
|
|
1908
1777
|
key=subkey,
|
|
1909
1778
|
df=sample_df,
|
|
1910
1779
|
shape=jnp.shape(sample_df),
|
|
1911
|
-
dtype=self.REAL
|
|
1780
|
+
dtype=self.REAL
|
|
1781
|
+
)
|
|
1912
1782
|
|
|
1913
1783
|
# compute L s.t. cov = L * L' and reparameterize
|
|
1914
1784
|
L = jnp.linalg.cholesky(sample_cov)
|
|
1915
1785
|
sample = jnp.matmul(L, Z)[..., 0] + sample_mean
|
|
1916
1786
|
sample = jnp.moveaxis(sample, source=-1, destination=index)
|
|
1917
1787
|
error = err1 | err2 | err3 | (out_of_bounds * ERR)
|
|
1918
|
-
return sample, key, error
|
|
1788
|
+
return sample, key, error, params
|
|
1919
1789
|
|
|
1920
1790
|
return _jax_wrapped_distribution_multivariate_student
|
|
1921
1791
|
|
|
1922
|
-
def _jax_dirichlet(self, expr,
|
|
1792
|
+
def _jax_dirichlet(self, expr, init_params):
|
|
1923
1793
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_DIRICHLET']
|
|
1924
1794
|
|
|
1925
1795
|
_, args = expr.args
|
|
1926
1796
|
alpha, = args
|
|
1927
|
-
jax_alpha = self._jax(alpha,
|
|
1797
|
+
jax_alpha = self._jax(alpha, init_params)
|
|
1928
1798
|
index, = self.traced.cached_sim_info(expr)
|
|
1929
1799
|
|
|
1930
1800
|
# sample Gamma(alpha_i, 1) and normalize across i
|
|
1931
1801
|
def _jax_wrapped_distribution_dirichlet(x, params, key):
|
|
1932
|
-
alpha, key, error = jax_alpha(x, params, key)
|
|
1802
|
+
alpha, key, error, params = jax_alpha(x, params, key)
|
|
1933
1803
|
out_of_bounds = jnp.logical_not(jnp.all(alpha > 0))
|
|
1934
1804
|
error |= (out_of_bounds * ERR)
|
|
1935
1805
|
key, subkey = random.split(key)
|
|
1936
1806
|
Gamma = random.gamma(key=subkey, a=alpha, dtype=self.REAL)
|
|
1937
1807
|
sample = Gamma / jnp.sum(Gamma, axis=-1, keepdims=True)
|
|
1938
1808
|
sample = jnp.moveaxis(sample, source=-1, destination=index)
|
|
1939
|
-
return sample, key, error
|
|
1809
|
+
return sample, key, error, params
|
|
1940
1810
|
|
|
1941
1811
|
return _jax_wrapped_distribution_dirichlet
|
|
1942
1812
|
|
|
1943
|
-
def _jax_multinomial(self, expr,
|
|
1813
|
+
def _jax_multinomial(self, expr, init_params):
|
|
1944
1814
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_MULTINOMIAL']
|
|
1945
1815
|
|
|
1946
1816
|
_, args = expr.args
|
|
1947
1817
|
trials, prob = args
|
|
1948
|
-
jax_trials = self._jax(trials,
|
|
1949
|
-
jax_prob = self._jax(prob,
|
|
1818
|
+
jax_trials = self._jax(trials, init_params)
|
|
1819
|
+
jax_prob = self._jax(prob, init_params)
|
|
1950
1820
|
index, = self.traced.cached_sim_info(expr)
|
|
1951
1821
|
|
|
1952
1822
|
def _jax_wrapped_distribution_multinomial(x, params, key):
|
|
1953
|
-
trials, key, err1 = jax_trials(x, params, key)
|
|
1954
|
-
prob, key, err2 = jax_prob(x, params, key)
|
|
1823
|
+
trials, key, err1, params = jax_trials(x, params, key)
|
|
1824
|
+
prob, key, err2, params = jax_prob(x, params, key)
|
|
1955
1825
|
trials = jnp.asarray(trials, dtype=self.REAL)
|
|
1956
1826
|
prob = jnp.asarray(prob, dtype=self.REAL)
|
|
1957
1827
|
key, subkey = random.split(key)
|
|
@@ -1961,9 +1831,10 @@ class JaxRDDLCompiler:
|
|
|
1961
1831
|
out_of_bounds = jnp.logical_not(jnp.all(
|
|
1962
1832
|
(prob >= 0)
|
|
1963
1833
|
& jnp.allclose(jnp.sum(prob, axis=-1), 1.0)
|
|
1964
|
-
& (trials >= 0)
|
|
1834
|
+
& (trials >= 0)
|
|
1835
|
+
))
|
|
1965
1836
|
error = err1 | err2 | (out_of_bounds * ERR)
|
|
1966
|
-
return sample, key, error
|
|
1837
|
+
return sample, key, error, params
|
|
1967
1838
|
|
|
1968
1839
|
return _jax_wrapped_distribution_multinomial
|
|
1969
1840
|
|
|
@@ -1971,57 +1842,57 @@ class JaxRDDLCompiler:
|
|
|
1971
1842
|
# matrix algebra
|
|
1972
1843
|
# ===========================================================================
|
|
1973
1844
|
|
|
1974
|
-
def _jax_matrix(self, expr,
|
|
1845
|
+
def _jax_matrix(self, expr, init_params):
|
|
1975
1846
|
_, op = expr.etype
|
|
1976
1847
|
if op == 'det':
|
|
1977
|
-
return self._jax_matrix_det(expr,
|
|
1848
|
+
return self._jax_matrix_det(expr, init_params)
|
|
1978
1849
|
elif op == 'inverse':
|
|
1979
|
-
return self._jax_matrix_inv(expr,
|
|
1850
|
+
return self._jax_matrix_inv(expr, init_params, pseudo=False)
|
|
1980
1851
|
elif op == 'pinverse':
|
|
1981
|
-
return self._jax_matrix_inv(expr,
|
|
1852
|
+
return self._jax_matrix_inv(expr, init_params, pseudo=True)
|
|
1982
1853
|
elif op == 'cholesky':
|
|
1983
|
-
return self._jax_matrix_cholesky(expr,
|
|
1854
|
+
return self._jax_matrix_cholesky(expr, init_params)
|
|
1984
1855
|
else:
|
|
1985
1856
|
raise RDDLNotImplementedError(
|
|
1986
1857
|
f'Matrix operation {op} is not supported.\n' +
|
|
1987
1858
|
print_stack_trace(expr))
|
|
1988
1859
|
|
|
1989
|
-
def _jax_matrix_det(self, expr,
|
|
1860
|
+
def _jax_matrix_det(self, expr, init_params):
|
|
1990
1861
|
* _, arg = expr.args
|
|
1991
|
-
jax_arg = self._jax(arg,
|
|
1862
|
+
jax_arg = self._jax(arg, init_params)
|
|
1992
1863
|
|
|
1993
1864
|
def _jax_wrapped_matrix_operation_det(x, params, key):
|
|
1994
|
-
sample_arg, key, error = jax_arg(x, params, key)
|
|
1865
|
+
sample_arg, key, error, params = jax_arg(x, params, key)
|
|
1995
1866
|
sample = jnp.linalg.det(sample_arg)
|
|
1996
|
-
return sample, key, error
|
|
1867
|
+
return sample, key, error, params
|
|
1997
1868
|
|
|
1998
1869
|
return _jax_wrapped_matrix_operation_det
|
|
1999
1870
|
|
|
2000
|
-
def _jax_matrix_inv(self, expr,
|
|
1871
|
+
def _jax_matrix_inv(self, expr, init_params, pseudo):
|
|
2001
1872
|
_, arg = expr.args
|
|
2002
|
-
jax_arg = self._jax(arg,
|
|
1873
|
+
jax_arg = self._jax(arg, init_params)
|
|
2003
1874
|
indices = self.traced.cached_sim_info(expr)
|
|
2004
1875
|
op = jnp.linalg.pinv if pseudo else jnp.linalg.inv
|
|
2005
1876
|
|
|
2006
1877
|
def _jax_wrapped_matrix_operation_inv(x, params, key):
|
|
2007
|
-
sample_arg, key, error = jax_arg(x, params, key)
|
|
1878
|
+
sample_arg, key, error, params = jax_arg(x, params, key)
|
|
2008
1879
|
sample = op(sample_arg)
|
|
2009
1880
|
sample = jnp.moveaxis(sample, source=(-2, -1), destination=indices)
|
|
2010
|
-
return sample, key, error
|
|
1881
|
+
return sample, key, error, params
|
|
2011
1882
|
|
|
2012
1883
|
return _jax_wrapped_matrix_operation_inv
|
|
2013
1884
|
|
|
2014
|
-
def _jax_matrix_cholesky(self, expr,
|
|
1885
|
+
def _jax_matrix_cholesky(self, expr, init_params):
|
|
2015
1886
|
_, arg = expr.args
|
|
2016
|
-
jax_arg = self._jax(arg,
|
|
1887
|
+
jax_arg = self._jax(arg, init_params)
|
|
2017
1888
|
indices = self.traced.cached_sim_info(expr)
|
|
2018
1889
|
op = jnp.linalg.cholesky
|
|
2019
1890
|
|
|
2020
1891
|
def _jax_wrapped_matrix_operation_cholesky(x, params, key):
|
|
2021
|
-
sample_arg, key, error = jax_arg(x, params, key)
|
|
1892
|
+
sample_arg, key, error, params = jax_arg(x, params, key)
|
|
2022
1893
|
sample = op(sample_arg)
|
|
2023
1894
|
sample = jnp.moveaxis(sample, source=(-2, -1), destination=indices)
|
|
2024
|
-
return sample, key, error
|
|
1895
|
+
return sample, key, error, params
|
|
2025
1896
|
|
|
2026
1897
|
return _jax_wrapped_matrix_operation_cholesky
|
|
2027
1898
|
|