pyRDDLGym-jax 0.1-py3-none-any.whl → 0.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/core/compiler.py +445 -221
- pyRDDLGym_jax/core/logic.py +129 -62
- pyRDDLGym_jax/core/planner.py +699 -332
- pyRDDLGym_jax/core/simulator.py +5 -7
- pyRDDLGym_jax/core/tuning.py +23 -12
- pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_drp.cfg → Cartpole_Continuous_gym_drp.cfg} +2 -3
- pyRDDLGym_jax/examples/configs/{HVAC_drp.cfg → HVAC_ippc2023_drp.cfg} +2 -2
- pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +19 -0
- pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +18 -0
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +18 -0
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +1 -1
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +1 -1
- pyRDDLGym_jax/examples/configs/default_drp.cfg +19 -0
- pyRDDLGym_jax/examples/configs/default_replan.cfg +20 -0
- pyRDDLGym_jax/examples/configs/default_slp.cfg +19 -0
- pyRDDLGym_jax/examples/run_gradient.py +1 -1
- pyRDDLGym_jax/examples/run_gym.py +1 -2
- pyRDDLGym_jax/examples/run_plan.py +7 -0
- pyRDDLGym_jax/examples/run_tune.py +6 -0
- {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.2.dist-info}/METADATA +1 -1
- pyRDDLGym_jax-0.2.dist-info/RECORD +46 -0
- {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.2.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax-0.1.dist-info/RECORD +0 -40
- /pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_replan.cfg → Cartpole_Continuous_gym_replan.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_slp.cfg → Cartpole_Continuous_gym_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{HVAC_slp.cfg → HVAC_ippc2023_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{MarsRover_drp.cfg → MarsRover_ippc2023_drp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{MarsRover_slp.cfg → MarsRover_ippc2023_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{MountainCar_slp.cfg → MountainCar_Continuous_gym_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{Pendulum_slp.cfg → Pendulum_gym_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{PowerGen_drp.cfg → PowerGen_Continuous_drp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{PowerGen_replan.cfg → PowerGen_Continuous_replan.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{PowerGen_slp.cfg → PowerGen_Continuous_slp.cfg} +0 -0
- {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.2.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.2.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/compiler.py
CHANGED
@@ -4,7 +4,7 @@ import jax.numpy as jnp
 import jax.random as random
 import jax.scipy as scipy
 import traceback
-from typing import Callable, Dict, List
+from typing import Any, Callable, Dict, List, Optional
 
 from pyRDDLGym.core.debug.exception import raise_warning
 
@@ -13,8 +13,9 @@ try:
     from tensorflow_probability.substrates import jax as tfp
 except Exception:
     raise_warning('Failed to import tensorflow-probability: '
-                  'compilation of some complex distributions
-                  '
+                  'compilation of some complex distributions '
+                  '(Binomial, Negative-Binomial, Multinomial) will fail. '
+                  'Please ensure this package is installed correctly.', 'red')
     traceback.print_exc()
     tfp = None
 
@@ -32,6 +33,98 @@ from pyRDDLGym.core.debug.logger import Logger
 from pyRDDLGym.core.simulator import RDDLSimulatorPrecompiled
 
 
+# ===========================================================================
+# EXACT RDDL TO JAX COMPILATION RULES
+# ===========================================================================
+
+def _function_unary_exact_named(op, name):
+    
+    def _jax_wrapped_unary_fn_exact(x, param):
+        return op(x)
+    
+    return _jax_wrapped_unary_fn_exact
+
+
+def _function_unary_exact_named_gamma():
+    
+    def _jax_wrapped_unary_gamma_exact(x, param):
+        return jnp.exp(scipy.special.gammaln(x))
+    
+    return _jax_wrapped_unary_gamma_exact
+
+
+def _function_binary_exact_named(op, name):
+    
+    def _jax_wrapped_binary_fn_exact(x, y, param):
+        return op(x, y)
+    
+    return _jax_wrapped_binary_fn_exact
+
+
+def _function_binary_exact_named_implies():
+    
+    def _jax_wrapped_binary_implies_exact(x, y, param):
+        return jnp.logical_or(jnp.logical_not(x), y)
+    
+    return _jax_wrapped_binary_implies_exact
+
+
+def _function_binary_exact_named_log():
+    
+    def _jax_wrapped_binary_log_exact(x, y, param):
+        return jnp.log(x) / jnp.log(y)
+    
+    return _jax_wrapped_binary_log_exact
+
+
+def _function_aggregation_exact_named(op, name):
+    
+    def _jax_wrapped_aggregation_fn_exact(x, axis, param):
+        return op(x, axis=axis)
+    
+    return _jax_wrapped_aggregation_fn_exact
+
+
+def _function_if_exact_named():
+    
+    def _jax_wrapped_if_exact(c, a, b, param):
+        return jnp.where(c, a, b)
+    
+    return _jax_wrapped_if_exact
+
+
+def _function_switch_exact_named():
+    
+    def _jax_wrapped_switch_exact(pred, cases, param):
+        pred = pred[jnp.newaxis, ...]
+        sample = jnp.take_along_axis(cases, pred, axis=0)
+        assert sample.shape[0] == 1
+        return sample[0, ...]
+    
+    return _jax_wrapped_switch_exact
+
+
+def _function_bernoulli_exact_named():
+    
+    def _jax_wrapped_bernoulli_exact(key, prob, param):
+        return random.bernoulli(key, prob)
+    
+    return _jax_wrapped_bernoulli_exact
+
+
+def _function_discrete_exact_named():
+    
+    def _jax_wrapped_discrete_exact(key, prob, param):
+        logits = jnp.log(prob)
+        sample = random.categorical(key=key, logits=logits, axis=-1)
+        out_of_bounds = jnp.logical_not(jnp.logical_and(
+            jnp.all(prob >= 0),
+            jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
+        return sample, out_of_bounds
+    
+    return _jax_wrapped_discrete_exact
+
+
 class JaxRDDLCompiler:
     '''Compiles a RDDL AST representation into an equivalent JAX representation.
     All operations are identical to their numpy equivalents.
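The exact switch rule above stacks all case values into one tensor and selects per element with jnp.take_along_axis, indexed by the integer (enum) predicate. A standalone sketch of that trick with made-up values (not part of the diff):

    import jax.numpy as jnp

    cases = jnp.asarray([[10., 20., 30.],    # case 0, one column per batch element
                         [40., 50., 60.]])   # case 1
    pred = jnp.asarray([0, 1, 0])            # integer predicate per element
    sample = jnp.take_along_axis(cases, pred[jnp.newaxis, ...], axis=0)
    assert sample.shape[0] == 1
    print(sample[0, ...])                    # [10. 50. 30.]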
@@ -39,10 +132,97 @@ class JaxRDDLCompiler:
 
     MODEL_PARAM_TAG_SEPARATOR = '___'
 
+    # ===========================================================================
+    # EXACT RDDL TO JAX COMPILATION RULES BY DEFAULT
+    # ===========================================================================
+    
+    EXACT_RDDL_TO_JAX_NEGATIVE = _function_unary_exact_named(jnp.negative, 'negative')
+    
+    EXACT_RDDL_TO_JAX_ARITHMETIC = {
+        '+': _function_binary_exact_named(jnp.add, 'add'),
+        '-': _function_binary_exact_named(jnp.subtract, 'subtract'),
+        '*': _function_binary_exact_named(jnp.multiply, 'multiply'),
+        '/': _function_binary_exact_named(jnp.divide, 'divide')
+    }
+    
+    EXACT_RDDL_TO_JAX_RELATIONAL = {
+        '>=': _function_binary_exact_named(jnp.greater_equal, 'greater_equal'),
+        '<=': _function_binary_exact_named(jnp.less_equal, 'less_equal'),
+        '<': _function_binary_exact_named(jnp.less, 'less'),
+        '>': _function_binary_exact_named(jnp.greater, 'greater'),
+        '==': _function_binary_exact_named(jnp.equal, 'equal'),
+        '~=': _function_binary_exact_named(jnp.not_equal, 'not_equal')
+    }
+    
+    EXACT_RDDL_TO_JAX_LOGICAL = {
+        '^': _function_binary_exact_named(jnp.logical_and, 'and'),
+        '&': _function_binary_exact_named(jnp.logical_and, 'and'),
+        '|': _function_binary_exact_named(jnp.logical_or, 'or'),
+        '~': _function_binary_exact_named(jnp.logical_xor, 'xor'),
+        '=>': _function_binary_exact_named_implies(),
+        '<=>': _function_binary_exact_named(jnp.equal, 'iff')
+    }
+    
+    EXACT_RDDL_TO_JAX_LOGICAL_NOT = _function_unary_exact_named(jnp.logical_not, 'not')
+    
+    EXACT_RDDL_TO_JAX_AGGREGATION = {
+        'sum': _function_aggregation_exact_named(jnp.sum, 'sum'),
+        'avg': _function_aggregation_exact_named(jnp.mean, 'avg'),
+        'prod': _function_aggregation_exact_named(jnp.prod, 'prod'),
+        'minimum': _function_aggregation_exact_named(jnp.min, 'minimum'),
+        'maximum': _function_aggregation_exact_named(jnp.max, 'maximum'),
+        'forall': _function_aggregation_exact_named(jnp.all, 'forall'),
+        'exists': _function_aggregation_exact_named(jnp.any, 'exists'),
+        'argmin': _function_aggregation_exact_named(jnp.argmin, 'argmin'),
+        'argmax': _function_aggregation_exact_named(jnp.argmax, 'argmax')
+    }
+    
+    EXACT_RDDL_TO_JAX_UNARY = {
+        'abs': _function_unary_exact_named(jnp.abs, 'abs'),
+        'sgn': _function_unary_exact_named(jnp.sign, 'sgn'),
+        'round': _function_unary_exact_named(jnp.round, 'round'),
+        'floor': _function_unary_exact_named(jnp.floor, 'floor'),
+        'ceil': _function_unary_exact_named(jnp.ceil, 'ceil'),
+        'cos': _function_unary_exact_named(jnp.cos, 'cos'),
+        'sin': _function_unary_exact_named(jnp.sin, 'sin'),
+        'tan': _function_unary_exact_named(jnp.tan, 'tan'),
+        'acos': _function_unary_exact_named(jnp.arccos, 'acos'),
+        'asin': _function_unary_exact_named(jnp.arcsin, 'asin'),
+        'atan': _function_unary_exact_named(jnp.arctan, 'atan'),
+        'cosh': _function_unary_exact_named(jnp.cosh, 'cosh'),
+        'sinh': _function_unary_exact_named(jnp.sinh, 'sinh'),
+        'tanh': _function_unary_exact_named(jnp.tanh, 'tanh'),
+        'exp': _function_unary_exact_named(jnp.exp, 'exp'),
+        'ln': _function_unary_exact_named(jnp.log, 'ln'),
+        'sqrt': _function_unary_exact_named(jnp.sqrt, 'sqrt'),
+        'lngamma': _function_unary_exact_named(scipy.special.gammaln, 'lngamma'),
+        'gamma': _function_unary_exact_named_gamma()
+    }
+    
+    EXACT_RDDL_TO_JAX_BINARY = {
+        'div': _function_binary_exact_named(jnp.floor_divide, 'div'),
+        'mod': _function_binary_exact_named(jnp.mod, 'mod'),
+        'fmod': _function_binary_exact_named(jnp.mod, 'fmod'),
+        'min': _function_binary_exact_named(jnp.minimum, 'min'),
+        'max': _function_binary_exact_named(jnp.maximum, 'max'),
+        'pow': _function_binary_exact_named(jnp.power, 'pow'),
+        'log': _function_binary_exact_named_log(),
+        'hypot': _function_binary_exact_named(jnp.hypot, 'hypot'),
+    }
+    
+    EXACT_RDDL_TO_JAX_IF = _function_if_exact_named()
+    
+    EXACT_RDDL_TO_JAX_SWITCH = _function_switch_exact_named()
+    
+    EXACT_RDDL_TO_JAX_BERNOULLI = _function_bernoulli_exact_named()
+    
+    EXACT_RDDL_TO_JAX_DISCRETE = _function_discrete_exact_named()
+    
     def __init__(self, rddl: RDDLLiftedModel,
                  allow_synchronous_state: bool=True,
-                 logger: Logger=None,
-                 use64bit: bool=False
+                 logger: Optional[Logger]=None,
+                 use64bit: bool=False,
+                 compile_non_fluent_exact: bool=True) -> None:
         '''Creates a new RDDL to Jax compiler.
         
         :param rddl: the RDDL model to compile into Jax
@@ -50,11 +230,14 @@ class JaxRDDLCompiler:
         on each other
         :param logger: to log information about compilation to file
         :param use64bit: whether to use 64 bit arithmetic
+        :param compile_non_fluent_exact: whether non-fluent expressions
+        are always compiled using exact JAX expressions.
         '''
         self.rddl = rddl
         self.logger = logger
         # jax.config.update('jax_log_compiles', True) # for testing ONLY
         
+        self.use64bit = use64bit
         if use64bit:
             self.INT = jnp.int64
             self.REAL = jnp.float64
@@ -50,11 +230,14 @@ class JaxRDDLCompiler:
|
|
|
50
230
|
on each other
|
|
51
231
|
:param logger: to log information about compilation to file
|
|
52
232
|
:param use64bit: whether to use 64 bit arithmetic
|
|
233
|
+
:param compile_non_fluent_exact: whether non-fluent expressions
|
|
234
|
+
are always compiled using exact JAX expressions.
|
|
53
235
|
'''
|
|
54
236
|
self.rddl = rddl
|
|
55
237
|
self.logger = logger
|
|
56
238
|
# jax.config.update('jax_log_compiles', True) # for testing ONLY
|
|
57
239
|
|
|
240
|
+
self.use64bit = use64bit
|
|
58
241
|
if use64bit:
|
|
59
242
|
self.INT = jnp.int64
|
|
60
243
|
self.REAL = jnp.float64
|
|
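A minimal usage sketch of the new constructor signature documented above. The domain/instance identifiers and the env.model attribute are assumptions for illustration, not taken from this diff:

    import pyRDDLGym
    from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler

    env = pyRDDLGym.make('HVAC_ippc2023', '1', vectorized=True)  # hypothetical ids
    compiler = JaxRDDLCompiler(env.model,
                               use64bit=True,                    # 64-bit INT/REAL
                               compile_non_fluent_exact=True)    # new in 0.2
    compiler.compile(log_jax_expr=False)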
@@ -62,6 +245,7 @@ class JaxRDDLCompiler:
         else:
             self.INT = jnp.int32
             self.REAL = jnp.float32
+            jax.config.update('jax_enable_x64', False)
         self.ONE = jnp.asarray(1, dtype=self.INT)
         self.JAX_TYPES = {
             'int': self.INT,
@@ -70,17 +254,16 @@ class JaxRDDLCompiler:
         }
         
         # compile initial values
-
-        self.logger.clear()
-        initializer = RDDLValueInitializer(rddl, logger=self.logger)
+        initializer = RDDLValueInitializer(rddl)
         self.init_values = initializer.initialize()
         
         # compute dependency graph for CPFs and sort them by evaluation order
-        sorter = RDDLLevelAnalysis(
+        sorter = RDDLLevelAnalysis(
+            rddl, allow_synchronous_state=allow_synchronous_state)
         self.levels = sorter.compute_levels()
         
         # trace expressions to cache information to be used later
-        tracer = RDDLObjectsTracer(rddl,
+        tracer = RDDLObjectsTracer(rddl, cpf_levels=self.levels)
         self.traced = tracer.trace()
         
         # extract the box constraints on actions
@@ -92,92 +275,42 @@ class JaxRDDLCompiler:
         constraints = RDDLConstraints(simulator, vectorized=True)
         self.constraints = constraints
         
-        # basic operations
-        self.
-        self.
-
-
-
-
-        self.RELATIONAL_OPS = {
-            '>=': lambda x, y, param: jnp.greater_equal(x, y),
-            '<=': lambda x, y, param: jnp.less_equal(x, y),
-            '<': lambda x, y, param: jnp.less(x, y),
-            '>': lambda x, y, param: jnp.greater(x, y),
-            '==': lambda x, y, param: jnp.equal(x, y),
-            '~=': lambda x, y, param: jnp.not_equal(x, y)
-        }
-        self.LOGICAL_NOT = lambda x, param: jnp.logical_not(x)
-        self.LOGICAL_OPS = {
-            '^': lambda x, y, param: jnp.logical_and(x, y),
-            '&': lambda x, y, param: jnp.logical_and(x, y),
-            '|': lambda x, y, param: jnp.logical_or(x, y),
-            '~': lambda x, y, param: jnp.logical_xor(x, y),
-            '=>': lambda x, y, param: jnp.logical_or(jnp.logical_not(x), y),
-            '<=>': lambda x, y, param: jnp.equal(x, y)
-        }
-        self.AGGREGATION_OPS = {
-            'sum': lambda x, axis, param: jnp.sum(x, axis=axis),
-            'avg': lambda x, axis, param: jnp.mean(x, axis=axis),
-            'prod': lambda x, axis, param: jnp.prod(x, axis=axis),
-            'minimum': lambda x, axis, param: jnp.min(x, axis=axis),
-            'maximum': lambda x, axis, param: jnp.max(x, axis=axis),
-            'forall': lambda x, axis, param: jnp.all(x, axis=axis),
-            'exists': lambda x, axis, param: jnp.any(x, axis=axis),
-            'argmin': lambda x, axis, param: jnp.argmin(x, axis=axis),
-            'argmax': lambda x, axis, param: jnp.argmax(x, axis=axis)
-        }
+        # basic operations - these can be override in subclasses
+        self.compile_non_fluent_exact = compile_non_fluent_exact
+        self.NEGATIVE = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_NEGATIVE
+        self.ARITHMETIC_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_ARITHMETIC.copy()
+        self.RELATIONAL_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_RELATIONAL.copy()
+        self.LOGICAL_NOT = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL_NOT
+        self.LOGICAL_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL.copy()
+        self.AGGREGATION_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_AGGREGATION.copy()
         self.AGGREGATION_BOOL = {'forall', 'exists'}
-        self.KNOWN_UNARY =
-
-
-
-
-
-            'sin': lambda x, param: jnp.sin(x),
-            'tan': lambda x, param: jnp.tan(x),
-            'acos': lambda x, param: jnp.arccos(x),
-            'asin': lambda x, param: jnp.arcsin(x),
-            'atan': lambda x, param: jnp.arctan(x),
-            'cosh': lambda x, param: jnp.cosh(x),
-            'sinh': lambda x, param: jnp.sinh(x),
-            'tanh': lambda x, param: jnp.tanh(x),
-            'exp': lambda x, param: jnp.exp(x),
-            'ln': lambda x, param: jnp.log(x),
-            'sqrt': lambda x, param: jnp.sqrt(x),
-            'lngamma': lambda x, param: scipy.special.gammaln(x),
-            'gamma': lambda x, param: jnp.exp(scipy.special.gammaln(x))
-        }
-        self.KNOWN_BINARY = {
-            'div': lambda x, y, param: jnp.floor_divide(x, y),
-            'mod': lambda x, y, param: jnp.mod(x, y),
-            'fmod': lambda x, y, param: jnp.mod(x, y),
-            'min': lambda x, y, param: jnp.minimum(x, y),
-            'max': lambda x, y, param: jnp.maximum(x, y),
-            'pow': lambda x, y, param: jnp.power(x, y),
-            'log': lambda x, y, param: jnp.log(x) / jnp.log(y),
-            'hypot': lambda x, y, param: jnp.hypot(x, y)
-        }
-
+        self.KNOWN_UNARY = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_UNARY.copy()
+        self.KNOWN_BINARY = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BINARY.copy()
+        self.IF_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_IF
+        self.SWITCH_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
+        self.BERNOULLI_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BERNOULLI
+        self.DISCRETE_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
+    
     # ===========================================================================
     # main compilation subroutines
     # ===========================================================================
     
-    def compile(self, log_jax_expr: bool=False) -> None:
+    def compile(self, log_jax_expr: bool=False, heading: str='') -> None:
         '''Compiles the current RDDL into Jax expressions.
         
         :param log_jax_expr: whether to pretty-print the compiled Jax functions
         to the log file
+        :param heading: the heading to print before compilation information
         '''
-        info = {}
+        info = ({}, [])
         self.invariants = self._compile_constraints(self.rddl.invariants, info)
         self.preconditions = self._compile_constraints(self.rddl.preconditions, info)
         self.terminations = self._compile_constraints(self.rddl.terminations, info)
         self.cpfs = self._compile_cpfs(info)
         self.reward = self._compile_reward(info)
-        self.model_params =
+        self.model_params = {key: value
+                             for (key, (value, *_)) in info[0].items()}
+        self.relaxations = info[1]
         
         if log_jax_expr and self.logger is not None:
             printed = self.print_jax()
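The comment "these can be override in subclasses" refers to the operation tables assigned above. A hedged sketch of that override pattern (illustrative only; the relaxed compiler shipped in planner.py replaces these tables wholesale rather than one entry at a time):

    import jax
    from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler

    class SoftGreaterCompiler(JaxRDDLCompiler):
        def __init__(self, *args, sharpness: float=10.0, **kwargs):
            super().__init__(*args, **kwargs)
            # swap the exact > for a differentiable sigmoid surrogate;
            # every other operation stays exact
            self.RELATIONAL_OPS['>'] = (
                lambda x, y, param: jax.nn.sigmoid(sharpness * (x - y)))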
@@ -189,6 +322,7 @@ class JaxRDDLCompiler:
             printed_terminals = '\n\n'.join(v for v in printed['terminations'])
             printed_params = '\n'.join(f'{k}: {v}' for (k, v) in info.items())
             message = (
+                f'[info] {heading}\n'
                 f'[info] compiled JAX CPFs:\n\n'
                 f'{printed_cpfs}\n\n'
                 f'[info] compiled JAX reward:\n\n'
@@ -281,17 +415,21 @@ class JaxRDDLCompiler:
         return jax_inequalities, jax_equalities
     
     def compile_transition(self, check_constraints: bool=False,
-                           constraint_func: bool=False):
+                           constraint_func: bool=False) -> Callable:
         '''Compiles the current RDDL model into a JAX transition function that
         samples the next state.
         
-        The
-        model_params), where:
+        The arguments of the returned function is:
            - key is the PRNG key
            - actions is the dict of action tensors
            - subs is the dict of current pvar value tensors
            - model_params is a dict of parameters for the relaxed model.
-
+        
+        The returned value of the function is:
+           - subs is the returned next epoch fluent values
+           - log includes all the auxiliary information about constraints
+           satisfied, errors, etc.
+        
         constraint_func provides the option to compile nonlinear constraints:
         
         1. f(s, a) ?? g(s, a)
@@ -361,6 +499,10 @@ class JaxRDDLCompiler:
             reward, key, err = reward_fn(subs, model_params, key)
             errors |= err
             
+            # calculate fluent values
+            fluents = {name: values for (name, values) in subs.items()
+                       if name not in rddl.non_fluents}
+            
             # set the next state to the current state
             for (state, next_state) in rddl.next_state.items():
                 subs[state] = subs[next_state]
@@ -383,8 +525,7 @@ class JaxRDDLCompiler:
             
             # prepare the return value
             log = {
-                '
-                'action': actions,
+                'fluents': fluents,
                 'reward': reward,
                 'error': errors,
                 'precondition': precond_check,
@@ -395,7 +536,7 @@ class JaxRDDLCompiler:
                 log['inequalities'] = inequalities
                 log['equalities'] = equalities
             
-            return log
+            return subs, log
         
         return _jax_wrapped_single_step
     
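Continuing the constructor sketch from earlier, a hedged example of consuming the compiled transition function under the new (subs, log) return contract (the actions dict is a placeholder to be filled with action tensors for the domain):

    import jax
    from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler

    step_fn = compiler.compile_transition(check_constraints=True)
    key = jax.random.PRNGKey(42)
    actions = {}   # placeholder: dict of action tensors
    subs, log = step_fn(key, actions, dict(compiler.init_values),
                        compiler.model_params)
    print(float(log['reward']))
    print(JaxRDDLCompiler.get_error_messages(int(log['error'])))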
@@ -403,18 +544,28 @@ class JaxRDDLCompiler:
                         n_steps: int,
                         n_batch: int,
                         check_constraints: bool=False,
-                        constraint_func: bool=False):
+                        constraint_func: bool=False) -> Callable:
         '''Compiles the current RDDL model into a JAX transition function that
         samples trajectories with a fixed horizon from a policy.
         
-        The
-
+        The arguments of the returned function is:
+           - key is the PRNG key (used by a stochastic policy)
+           - policy_params is a pytree of trainable policy weights
+           - hyperparams is a pytree of (optional) fixed policy hyper-parameters
+           - subs is the dictionary of current fluent tensor values
+           - model_params is a dict of model hyperparameters.
+        
+        The returned value of the returned function is:
+           - log is the dictionary of all trajectory information, including
+           constraints that were satisfied, errors, etc.
+        
+        The arguments of the policy function is:
           - key is the PRNG key (used by a stochastic policy)
           - params is a pytree of trainable policy weights
           - hyperparams is a pytree of (optional) fixed policy hyper-parameters
           - step is the time index of the decision in the current rollout
           - states is a dict of tensors for the current observation.
-
+        
         :param policy: a Jax compiled function for the policy as described above
         decision epoch, state dict, and an RNG key and returns an action dict
         :param n_steps: the rollout horizon
@@ -428,27 +579,32 @@ class JaxRDDLCompiler:
         rddl = self.rddl
         jax_step_fn = self.compile_transition(check_constraints, constraint_func)
         
+        # for POMDP only observ-fluents are assumed visible to the policy
+        if rddl.observ_fluents:
+            observed_vars = rddl.observ_fluents
+        else:
+            observed_vars = rddl.state_fluents
+        
         # evaluate the step from the policy
         def _jax_wrapped_single_step_policy(key, policy_params, hyperparams,
                                             step, subs, model_params):
             states = {var: values
                       for (var, values) in subs.items()
-                      if
+                      if var in observed_vars}
             actions = policy(key, policy_params, hyperparams, step, states)
             key, subkey = random.split(key)
-            log = jax_step_fn(subkey, actions, subs, model_params)
-            return log
+            subs, log = jax_step_fn(subkey, actions, subs, model_params)
+            return subs, log
         
         # do a batched step update from the policy
         def _jax_wrapped_batched_step_policy(carry, step):
             key, policy_params, hyperparams, subs, model_params = carry
             key, *subkeys = random.split(key, num=1 + n_batch)
             keys = jnp.asarray(subkeys)
-            log = jax.vmap(
+            subs, log = jax.vmap(
                 _jax_wrapped_single_step_policy,
                 in_axes=(0, None, None, None, 0, None)
             )(keys, policy_params, hyperparams, step, subs, model_params)
-            subs = log['pvar']
             carry = (key, policy_params, hyperparams, subs, model_params)
             return carry, log
         
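The (carry, step) -> (carry, log) signature of _jax_wrapped_batched_step_policy above is the contract jax.lax.scan expects, with (key, policy_params, hyperparams, subs, model_params) threaded as the carry and one log pytree emitted per step. A standalone toy sketch of that carry/emit pattern (not the real rollout):

    import jax
    import jax.numpy as jnp

    def step(carry, t):
        key, total = carry
        key, subkey = jax.random.split(key)
        reward = jax.random.normal(subkey)       # stand-in for one transition's reward
        return (key, total + reward), reward     # (new carry, per-step log)

    init = (jax.random.PRNGKey(0), jnp.asarray(0.0))
    (_, total), rewards = jax.lax.scan(step, init, jnp.arange(5))
    print(float(total), rewards.shape)           # cumulative return, (5,) log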
@@ -467,7 +623,7 @@ class JaxRDDLCompiler:
     # error checks
     # ===========================================================================
     
-    def print_jax(self) -> Dict[str,
+    def print_jax(self) -> Dict[str, Any]:
         '''Returns a dictionary containing the string representations of all
         Jax compiled expressions from the RDDL file.
         '''
@@ -564,7 +720,7 @@ class JaxRDDLCompiler:
     }
     
     @staticmethod
-    def get_error_codes(error):
+    def get_error_codes(error: int) -> List[int]:
         '''Given a compacted integer error flag from the execution of Jax, and
         decomposes it into individual error codes.
         '''
@@ -573,7 +729,7 @@ class JaxRDDLCompiler:
         return errors
     
     @staticmethod
-    def get_error_messages(error):
+    def get_error_messages(error: int) -> List[str]:
         '''Given a compacted integer error flag from the execution of Jax, and
         decomposes it into error strings.
         '''
@@ -586,28 +742,40 @@ class JaxRDDLCompiler:
     # ===========================================================================
     
     def _unwrap(self, op, expr_id, info):
-        sep = JaxRDDLCompiler.MODEL_PARAM_TAG_SEPARATOR
         jax_op, name = op, None
+        model_params, relaxed_list = info
         if isinstance(op, tuple):
             jax_op, param = op
             if param is not None:
                 tags, values = param
+                sep = JaxRDDLCompiler.MODEL_PARAM_TAG_SEPARATOR
                 if isinstance(tags, tuple):
                     name = sep.join(tags)
                 else:
                     name = str(tags)
                 name = f'{name}{sep}{expr_id}'
-                if name in
-                raise
-
+                if name in model_params:
+                    raise RuntimeError(
+                        f'Internal error: model parameter {name} is already defined.')
+                model_params[name] = (values, tags, expr_id, jax_op.__name__)
+                relaxed_list.append((param, expr_id, jax_op.__name__))
         return jax_op, name
     
-    def
-    '''Returns a
-
-
-
-
+    def summarize_model_relaxations(self) -> str:
+        '''Returns a string of information about model relaxations in the
+        compiled model.'''
+        occurence_by_type = {}
+        for (_, expr_id, jax_op) in self.relaxations:
+            etype = self.traced.lookup(expr_id).etype
+            source = f'{etype[1]} ({etype[0]})'
+            sub = f'{source:<30} --> {jax_op}'
+            occurence_by_type[sub] = occurence_by_type.get(sub, 0) + 1
+        col = "{:<80} {:<10}\n"
+        table = col.format('Substitution', 'Count')
+        for (sub, occurs) in occurence_by_type.items():
+            table += col.format(sub, occurs)
+        return table
+    
     # ===========================================================================
     # expression compilation
     # ===========================================================================
@@ -640,7 +808,8 @@ class JaxRDDLCompiler:
             raise RDDLNotImplementedError(
                 f'Internal error: expression type {expr} is not supported.\n' +
                 print_stack_trace(expr))
-
+        
+        # force type cast of tensor as required by caller
         if dtype is not None:
             jax_expr = self._jax_cast(jax_expr, dtype)
         
@@ -660,6 +829,17 @@ class JaxRDDLCompiler:
         
         return _jax_wrapped_cast
     
+    def _fix_dtype(self, value):
+        dtype = jnp.atleast_1d(value).dtype
+        if jnp.issubdtype(dtype, jnp.integer):
+            return self.INT
+        elif jnp.issubdtype(dtype, jnp.floating):
+            return self.REAL
+        elif jnp.issubdtype(dtype, jnp.bool_) or jnp.issubdtype(dtype, bool):
+            return bool
+        else:
+            raise TypeError(f'Invalid type {dtype} of {value}.')
+    
     # ===========================================================================
     # leaves
     # ===========================================================================
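The new _fix_dtype helper maps whatever dtype a leaf value carries onto the compiler's canonical types: integers to self.INT, floats to self.REAL, booleans to bool. A standalone sketch of the same dtype probing it performs:

    import jax.numpy as jnp

    for value in (3, 2.5, True):
        dtype = jnp.atleast_1d(value).dtype
        kind = ('integer' if jnp.issubdtype(dtype, jnp.integer) else
                'floating' if jnp.issubdtype(dtype, jnp.floating) else 'bool')
        print(value, '->', dtype, kind)   # int32/float32/bool under 32-bit JAX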
@@ -669,7 +849,7 @@ class JaxRDDLCompiler:
         cached_value = self.traced.cached_sim_info(expr)
         
         def _jax_wrapped_constant(x, params, key):
-            sample = jnp.asarray(cached_value)
+            sample = jnp.asarray(cached_value, dtype=self._fix_dtype(cached_value))
             return sample, key, NORMAL
         
         return _jax_wrapped_constant
@@ -693,7 +873,7 @@ class JaxRDDLCompiler:
             cached_value = cached_info
             
             def _jax_wrapped_object(x, params, key):
-                sample = jnp.asarray(cached_value)
+                sample = jnp.asarray(cached_value, dtype=self._fix_dtype(cached_value))
                 return sample, key, NORMAL
             
             return _jax_wrapped_object
@@ -702,7 +882,8 @@ class JaxRDDLCompiler:
         elif cached_info is None:
             
             def _jax_wrapped_pvar_scalar(x, params, key):
-
+                value = x[var]
+                sample = jnp.asarray(value, dtype=self._fix_dtype(value))
                 return sample, key, NORMAL
             
             return _jax_wrapped_pvar_scalar
@@ -721,7 +902,8 @@ class JaxRDDLCompiler:
         
             def _jax_wrapped_pvar_tensor_nested(x, params, key):
                 error = NORMAL
-
+                value = x[var]
+                sample = jnp.asarray(value, dtype=self._fix_dtype(value))
                 new_slices = [None] * len(jax_nested_expr)
                 for (i, jax_expr) in enumerate(jax_nested_expr):
                     new_slices[i], key, err = jax_expr(x, params, key)
@@ -736,7 +918,8 @@ class JaxRDDLCompiler:
         else:
             
             def _jax_wrapped_pvar_tensor_non_nested(x, params, key):
-
+                value = x[var]
+                sample = jnp.asarray(value, dtype=self._fix_dtype(value))
                 if slices:
                     sample = sample[slices]
                 if axis:
@@ -795,16 +978,23 @@ class JaxRDDLCompiler:
     
     def _jax_arithmetic(self, expr, info):
         _, op = expr.etype
-
+        
+        # if expression is non-fluent, always use the exact operation
+        if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
+            valid_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_ARITHMETIC
+            negative_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_NEGATIVE
+        else:
+            valid_ops = self.ARITHMETIC_OPS
+            negative_op = self.NEGATIVE
         JaxRDDLCompiler._check_valid_op(expr, valid_ops)
-
+        
+        # recursively compile arguments
         args = expr.args
         n = len(args)
-
         if n == 1 and op == '-':
             arg, = args
             jax_expr = self._jax(arg, info)
-            jax_op, jax_param = self._unwrap(
+            jax_op, jax_param = self._unwrap(negative_op, expr.id, info)
             return self._jax_unary(jax_expr, jax_op, jax_param, at_least_int=True)
         
         elif n == 2:
@@ -819,29 +1009,42 @@ class JaxRDDLCompiler:
     
     def _jax_relational(self, expr, info):
         _, op = expr.etype
-
+        
+        # if expression is non-fluent, always use the exact operation
+        if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
+            valid_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_RELATIONAL
+        else:
+            valid_ops = self.RELATIONAL_OPS
         JaxRDDLCompiler._check_valid_op(expr, valid_ops)
-
+        jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
         
+        # recursively compile arguments
+        JaxRDDLCompiler._check_num_args(expr, 2)
         lhs, rhs = expr.args
         jax_lhs = self._jax(lhs, info)
         jax_rhs = self._jax(rhs, info)
-        jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
         return self._jax_binary(
             jax_lhs, jax_rhs, jax_op, jax_param, at_least_int=True)
     
     def _jax_logical(self, expr, info):
         _, op = expr.etype
-
+        
+        # if expression is non-fluent, always use the exact operation
+        if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
+            valid_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL
+            logical_not_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL_NOT
+        else:
+            valid_ops = self.LOGICAL_OPS
+            logical_not_op = self.LOGICAL_NOT
         JaxRDDLCompiler._check_valid_op(expr, valid_ops)
         
+        # recursively compile arguments
         args = expr.args
-        n = len(args)
-
+        n = len(args)
         if n == 1 and op == '~':
             arg, = args
             jax_expr = self._jax(arg, info)
-            jax_op, jax_param = self._unwrap(
+            jax_op, jax_param = self._unwrap(logical_not_op, expr.id, info)
             return self._jax_unary(jax_expr, jax_op, jax_param, check_dtype=bool)
         
         elif n == 2:
@@ -856,17 +1059,21 @@ class JaxRDDLCompiler:
     
     def _jax_aggregation(self, expr, info):
         ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
-
         _, op = expr.etype
-
+        
+        # if expression is non-fluent, always use the exact operation
+        if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
+            valid_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_AGGREGATION
+        else:
+            valid_ops = self.AGGREGATION_OPS
         JaxRDDLCompiler._check_valid_op(expr, valid_ops)
-
+        jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
         
+        # recursively compile arguments
+        is_floating = op not in self.AGGREGATION_BOOL
         * _, arg = expr.args
-        _, axes = self.traced.cached_sim_info(expr)
-
+        _, axes = self.traced.cached_sim_info(expr)
         jax_expr = self._jax(arg, info)
-        jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
         
         def _jax_wrapped_aggregation(x, params, key):
             sample, key, err = jax_expr(x, params, key)
@@ -884,21 +1091,28 @@ class JaxRDDLCompiler:
     def _jax_functional(self, expr, info):
         _, op = expr.etype
         
-        #
-        if
+        # if expression is non-fluent, always use the exact operation
+        if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
+            unary_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_UNARY
+            binary_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BINARY
+        else:
+            unary_ops = self.KNOWN_UNARY
+            binary_ops = self.KNOWN_BINARY
+        
+        # recursively compile arguments
+        if op in unary_ops:
             JaxRDDLCompiler._check_num_args(expr, 1)
             arg, = expr.args
             jax_expr = self._jax(arg, info)
-            jax_op, jax_param = self._unwrap(
+            jax_op, jax_param = self._unwrap(unary_ops[op], expr.id, info)
             return self._jax_unary(jax_expr, jax_op, jax_param, at_least_int=True)
         
-
-        elif op in self.KNOWN_BINARY:
+        elif op in binary_ops:
             JaxRDDLCompiler._check_num_args(expr, 2)
             lhs, rhs = expr.args
             jax_lhs = self._jax(lhs, info)
             jax_rhs = self._jax(rhs, info)
-            jax_op, jax_param = self._unwrap(
+            jax_op, jax_param = self._unwrap(binary_ops[op], expr.id, info)
             return self._jax_binary(
                 jax_lhs, jax_rhs, jax_op, jax_param, at_least_int=True)
         
@@ -921,19 +1135,19 @@ class JaxRDDLCompiler:
                 f'Control operator {op} is not supported.\n' +
                 print_stack_trace(expr))
     
-    def _jax_if_helper(self):
-
-        def _jax_wrapped_if_calc_exact(c, a, b, param):
-            return jnp.where(c, a, b)
-
-        return _jax_wrapped_if_calc_exact
-
     def _jax_if(self, expr, info):
         ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
         JaxRDDLCompiler._check_num_args(expr, 3)
-
+        pred, if_true, if_false = expr.args
+        
+        # if predicate is non-fluent, always use the exact operation
+        if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(pred):
+            if_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_IF
+        else:
+            if_op = self.IF_HELPER
+        jax_if, jax_param = self._unwrap(if_op, expr.id, info)
         
-
+        # recursively compile arguments
         jax_pred = self._jax(pred, info)
         jax_true = self._jax(if_true, info)
         jax_false = self._jax(if_false, info)
@@ -951,23 +1165,20 @@ class JaxRDDLCompiler:
         
         return _jax_wrapped_if_then_else
     
-    def _jax_switch_helper(self):
-
-        def _jax_wrapped_switch_calc_exact(pred, cases, param):
-            pred = pred[jnp.newaxis, ...]
-            sample = jnp.take_along_axis(cases, pred, axis=0)
-            assert sample.shape[0] == 1
-            return sample[0, ...]
-
-        return _jax_wrapped_switch_calc_exact
-
     def _jax_switch(self, expr, info):
-
+        
+        # if expression is non-fluent, always use the exact operation
+        if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
+            switch_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
+        else:
+            switch_op = self.SWITCH_HELPER
+        jax_switch, jax_param = self._unwrap(switch_op, expr.id, info)
+        
+        # recursively compile predicate
+        pred, *_ = expr.args
         jax_pred = self._jax(pred, info)
-        jax_switch, jax_param = self._unwrap(
-            self._jax_switch_helper(), expr.id, info)
         
-        #
+        # recursively compile cases
         cases, default = self.traced.cached_sim_info(expr)
         jax_default = None if default is None else self._jax(default, info)
         jax_cases = [(jax_default if _case is None else self._jax(_case, info))
@@ -983,7 +1194,8 @@ class JaxRDDLCompiler:
             for (i, jax_case) in enumerate(jax_cases):
                 sample_cases[i], key, err_case = jax_case(x, params, key)
                 err |= err_case
-            sample_cases = jnp.asarray(
+            sample_cases = jnp.asarray(
+                sample_cases, dtype=self._fix_dtype(sample_cases))
             
             # predicate (enum) is an integer - use it to extract from case array
             param = params.get(jax_param, None)
@@ -1179,30 +1391,28 @@ class JaxRDDLCompiler:
             scale, key, err2 = jax_scale(x, params, key)
             key, subkey = random.split(key)
             U = random.uniform(key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
-            sample = scale * jnp.power(-jnp.
+            sample = scale * jnp.power(-jnp.log(U), 1.0 / shape)
             out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
             err = err1 | err2 | (out_of_bounds * ERR)
             return sample, key, err
         
         return _jax_wrapped_distribution_weibull
     
-    def _jax_bernoulli_helper(self):
-
-        def _jax_wrapped_calc_bernoulli_exact(key, prob, param):
-            return random.bernoulli(key, prob)
-
-        return _jax_wrapped_calc_bernoulli_exact
-
     def _jax_bernoulli(self, expr, info):
         ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_BERNOULLI']
         JaxRDDLCompiler._check_num_args(expr, 1)
-        jax_bern, jax_param = self._unwrap(
-            self._jax_bernoulli_helper(), expr.id, info)
-
         arg_prob, = expr.args
+        
+        # if probability is non-fluent, always use the exact operation
+        if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_prob):
+            bern_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BERNOULLI
+        else:
+            bern_op = self.BERNOULLI_HELPER
+        jax_bern, jax_param = self._unwrap(bern_op, expr.id, info)
+        
+        # recursively compile arguments
         jax_prob = self._jax(arg_prob, info)
         
-        # uses the implicit JAX subroutine
         def _jax_wrapped_distribution_bernoulli(x, params, key):
             prob, key, err = jax_prob(x, params, key)
             key, subkey = random.split(key)
@@ -1266,8 +1476,8 @@ class JaxRDDLCompiler:
         def _jax_wrapped_distribution_binomial(x, params, key):
             trials, key, err2 = jax_trials(x, params, key)
             prob, key, err1 = jax_prob(x, params, key)
-            trials = jnp.asarray(trials, self.REAL)
-            prob = jnp.asarray(prob, self.REAL)
+            trials = jnp.asarray(trials, dtype=self.REAL)
+            prob = jnp.asarray(prob, dtype=self.REAL)
             key, subkey = random.split(key)
             dist = tfp.distributions.Binomial(total_count=trials, probs=prob)
             sample = dist.sample(seed=subkey).astype(self.INT)
@@ -1290,11 +1500,10 @@ class JaxRDDLCompiler:
         def _jax_wrapped_distribution_negative_binomial(x, params, key):
             trials, key, err2 = jax_trials(x, params, key)
             prob, key, err1 = jax_prob(x, params, key)
-            trials = jnp.asarray(trials, self.REAL)
-            prob = jnp.asarray(prob, self.REAL)
+            trials = jnp.asarray(trials, dtype=self.REAL)
+            prob = jnp.asarray(prob, dtype=self.REAL)
             key, subkey = random.split(key)
-            dist = tfp.distributions.NegativeBinomial(
-                total_count=trials, probs=prob)
+            dist = tfp.distributions.NegativeBinomial(total_count=trials, probs=prob)
             sample = dist.sample(seed=subkey).astype(self.INT)
             out_of_bounds = jnp.logical_not(jnp.all(
                 (prob >= 0) & (prob <= 1) & (trials > 0)))
@@ -1316,7 +1525,7 @@ class JaxRDDLCompiler:
             shape, key, err1 = jax_shape(x, params, key)
             rate, key, err2 = jax_rate(x, params, key)
             key, subkey = random.split(key)
-            sample = random.beta(key=subkey, a=shape, b=rate)
+            sample = random.beta(key=subkey, a=shape, b=rate, dtype=self.REAL)
             out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (rate > 0)))
             err = err1 | err2 | (out_of_bounds * ERR)
             return sample, key, err
@@ -1325,23 +1534,35 @@ class JaxRDDLCompiler:
     
     def _jax_geometric(self, expr, info):
         ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_GEOMETRIC']
-        JaxRDDLCompiler._check_num_args(expr, 1)
-
+        JaxRDDLCompiler._check_num_args(expr, 1)
         arg_prob, = expr.args
         jax_prob = self._jax(arg_prob, info)
-        floor_op, jax_param = self._unwrap(
-            self.KNOWN_UNARY['floor'], expr.id, info)
         
-
-
-            prob
-
-
-
-
-
-
+        if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_prob):
+            
+            # prob is non-fluent: do not reparameterize
+            def _jax_wrapped_distribution_geometric(x, params, key):
+                prob, key, err = jax_prob(x, params, key)
+                key, subkey = random.split(key)
+                sample = random.geometric(key=subkey, p=prob, dtype=self.INT)
+                out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
+                err |= (out_of_bounds * ERR)
+                return sample, key, err
+        
+        else:
+            floor_op, jax_param = self._unwrap(
+                self.KNOWN_UNARY['floor'], expr.id, info)
+            
+            # reparameterization trick Geom(p) = floor(ln(U(0, 1)) / ln(p)) + 1
+            def _jax_wrapped_distribution_geometric(x, params, key):
+                prob, key, err = jax_prob(x, params, key)
+                key, subkey = random.split(key)
+                U = random.uniform(key=subkey, shape=jnp.shape(prob), dtype=self.REAL)
+                param = params.get(jax_param, None)
+                sample = floor_op(jnp.log(U) / jnp.log(1.0 - prob), param) + 1
+                out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
+                err |= (out_of_bounds * ERR)
+                return sample, key, err
         
         return _jax_wrapped_distribution_geometric
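A quick numeric sanity check of the inverse-CDF trick in the reparameterized branch above: for U ~ Uniform(0, 1), floor(ln U / ln(1 - p)) + 1 is Geometric(p) with mean 1/p (standalone sketch, not from the package):

    import jax
    import jax.numpy as jnp

    p = 0.3
    U = jax.random.uniform(jax.random.PRNGKey(0), shape=(100000,))
    sample = jnp.floor(jnp.log(U) / jnp.log(1.0 - p)) + 1
    print(float(sample.mean()))   # close to 1 / p = 3.33...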
@@ -1359,7 +1580,7 @@ class JaxRDDLCompiler:
             shape, key, err1 = jax_shape(x, params, key)
             scale, key, err2 = jax_scale(x, params, key)
             key, subkey = random.split(key)
-            sample = scale * random.pareto(key=subkey, b=shape)
+            sample = scale * random.pareto(key=subkey, b=shape, dtype=self.REAL)
             out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
             err = err1 | err2 | (out_of_bounds * ERR)
             return sample, key, err
@@ -1377,7 +1598,8 @@ class JaxRDDLCompiler:
         def _jax_wrapped_distribution_t(x, params, key):
             df, key, err = jax_df(x, params, key)
             key, subkey = random.split(key)
-            sample = random.t(
+            sample = random.t(
+                key=subkey, df=df, shape=jnp.shape(df), dtype=self.REAL)
             out_of_bounds = jnp.logical_not(jnp.all(df > 0))
             err |= (out_of_bounds * ERR)
             return sample, key, err
@@ -1464,7 +1686,7 @@ class JaxRDDLCompiler:
             scale, key, err2 = jax_scale(x, params, key)
             key, subkey = random.split(key)
             U = random.uniform(key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
-            sample = jnp.log(1.0 - jnp.
+            sample = jnp.log(1.0 - jnp.log(U) / shape) / scale
             out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
             err = err1 | err2 | (out_of_bounds * ERR)
             return sample, key, err
@@ -1516,25 +1738,21 @@ class JaxRDDLCompiler:
     # random variables with enum support
     # ===========================================================================
     
-    def _jax_discrete_helper(self):
-
-        def _jax_wrapped_discrete_calc_exact(key, prob, param):
-            logits = jnp.log(prob)
-            sample = random.categorical(key=key, logits=logits, axis=-1)
-            out_of_bounds = jnp.logical_not(jnp.logical_and(
-                jnp.all(prob >= 0),
-                jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
-            return sample, out_of_bounds
-
-        return _jax_wrapped_discrete_calc_exact
-
     def _jax_discrete(self, expr, info, unnorm):
         NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
         ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_DISCRETE']
-        jax_discrete, jax_param = self._unwrap(
-            self._jax_discrete_helper(), expr.id, info)
-
         ordered_args = self.traced.cached_sim_info(expr)
+        
+        # if all probabilities are non-fluent, then always sample exact
+        has_fluent_arg = any(self.traced.cached_is_fluent(arg)
+                             for arg in ordered_args)
+        if self.compile_non_fluent_exact and not has_fluent_arg:
+            discrete_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
+        else:
+            discrete_op = self.DISCRETE_HELPER
+        jax_discrete, jax_param = self._unwrap(discrete_op, expr.id, info)
+        
+        # compile probability expressions
         jax_probs = [self._jax(arg, info) for arg in ordered_args]
         
         def _jax_wrapped_distribution_discrete(x, params, key):
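A standalone check of the exact Discrete rule used above: a categorical draw via jax.random.categorical over log-probabilities, plus the same validity test applied to the probability vector (sketch only):

    import jax
    import jax.numpy as jnp

    prob = jnp.asarray([0.2, 0.5, 0.3])
    sample = jax.random.categorical(key=jax.random.PRNGKey(0),
                                    logits=jnp.log(prob), axis=-1)
    valid = jnp.all(prob >= 0) & jnp.allclose(jnp.sum(prob, axis=-1), 1.0)
    print(int(sample), bool(valid))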
@@ -1561,12 +1779,18 @@ class JaxRDDLCompiler:
     
     def _jax_discrete_pvar(self, expr, info, unnorm):
         ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_DISCRETE']
-        JaxRDDLCompiler._check_num_args(expr,
-        jax_discrete, jax_param = self._unwrap(
-            self._jax_discrete_helper(), expr.id, info)
-
+        JaxRDDLCompiler._check_num_args(expr, 2)
         _, args = expr.args
         arg, = args
+        
+        # if probabilities are non-fluent, then always sample exact
+        if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg):
+            discrete_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
+        else:
+            discrete_op = self.DISCRETE_HELPER
+        jax_discrete, jax_param = self._unwrap(discrete_op, expr.id, info)
+        
+        # compile probability function
         jax_probs = self._jax(arg, info)
         
         def _jax_wrapped_distribution_discrete_pvar(x, params, key):
@@ -1687,7 +1911,7 @@ class JaxRDDLCompiler:
             out_of_bounds = jnp.logical_not(jnp.all(alpha > 0))
             error |= (out_of_bounds * ERR)
             key, subkey = random.split(key)
-            Gamma = random.gamma(key=subkey, a=alpha)
+            Gamma = random.gamma(key=subkey, a=alpha, dtype=self.REAL)
             sample = Gamma / jnp.sum(Gamma, axis=-1, keepdims=True)
             sample = jnp.moveaxis(sample, source=-1, destination=index)
             return sample, key, error
@@ -1706,8 +1930,8 @@ class JaxRDDLCompiler:
         def _jax_wrapped_distribution_multinomial(x, params, key):
             trials, key, err1 = jax_trials(x, params, key)
             prob, key, err2 = jax_prob(x, params, key)
-            trials = jnp.asarray(trials, self.REAL)
-            prob = jnp.asarray(prob, self.REAL)
+            trials = jnp.asarray(trials, dtype=self.REAL)
+            prob = jnp.asarray(prob, dtype=self.REAL)
             key, subkey = random.split(key)
             dist = tfp.distributions.Multinomial(total_count=trials, probs=prob)
             sample = dist.sample(seed=subkey).astype(self.INT)