pyRDDLGym-jax 0.1__py3-none-any.whl → 0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -0
- pyRDDLGym_jax/core/compiler.py +444 -221
- pyRDDLGym_jax/core/logic.py +129 -62
- pyRDDLGym_jax/core/planner.py +965 -394
- pyRDDLGym_jax/core/simulator.py +5 -7
- pyRDDLGym_jax/core/tuning.py +29 -15
- pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_drp.cfg → Cartpole_Continuous_gym_drp.cfg} +2 -3
- pyRDDLGym_jax/examples/configs/{HVAC_drp.cfg → HVAC_ippc2023_drp.cfg} +4 -4
- pyRDDLGym_jax/examples/configs/{MarsRover_drp.cfg → MarsRover_ippc2023_drp.cfg} +1 -0
- pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +19 -0
- pyRDDLGym_jax/examples/configs/{Pendulum_slp.cfg → Pendulum_gym_slp.cfg} +1 -1
- pyRDDLGym_jax/examples/configs/{Pong_slp.cfg → Quadcopter_drp.cfg} +5 -5
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +18 -0
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +1 -1
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +1 -1
- pyRDDLGym_jax/examples/configs/default_drp.cfg +19 -0
- pyRDDLGym_jax/examples/configs/default_replan.cfg +20 -0
- pyRDDLGym_jax/examples/configs/default_slp.cfg +19 -0
- pyRDDLGym_jax/examples/run_gradient.py +1 -1
- pyRDDLGym_jax/examples/run_gym.py +3 -7
- pyRDDLGym_jax/examples/run_plan.py +10 -5
- pyRDDLGym_jax/examples/run_scipy.py +61 -0
- pyRDDLGym_jax/examples/run_tune.py +8 -3
- {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/METADATA +1 -1
- pyRDDLGym_jax-0.3.dist-info/RECORD +44 -0
- {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax/examples/configs/SupplyChain_slp.cfg +0 -18
- pyRDDLGym_jax/examples/configs/Traffic_slp.cfg +0 -20
- pyRDDLGym_jax-0.1.dist-info/RECORD +0 -40
- /pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_replan.cfg → Cartpole_Continuous_gym_replan.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_slp.cfg → Cartpole_Continuous_gym_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{HVAC_slp.cfg → HVAC_ippc2023_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{MarsRover_slp.cfg → MarsRover_ippc2023_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{MountainCar_slp.cfg → MountainCar_Continuous_gym_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{PowerGen_drp.cfg → PowerGen_Continuous_drp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{PowerGen_replan.cfg → PowerGen_Continuous_replan.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{PowerGen_slp.cfg → PowerGen_Continuous_slp.cfg} +0 -0
- {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/compiler.py
CHANGED
|
@@ -4,7 +4,7 @@ import jax.numpy as jnp
|
|
|
4
4
|
import jax.random as random
|
|
5
5
|
import jax.scipy as scipy
|
|
6
6
|
import traceback
|
|
7
|
-
from typing import Callable, Dict, List
|
|
7
|
+
from typing import Any, Callable, Dict, List, Optional
|
|
8
8
|
|
|
9
9
|
from pyRDDLGym.core.debug.exception import raise_warning
|
|
10
10
|
|
|
@@ -13,8 +13,8 @@ try:
|
|
|
13
13
|
from tensorflow_probability.substrates import jax as tfp
|
|
14
14
|
except Exception:
|
|
15
15
|
raise_warning('Failed to import tensorflow-probability: '
|
|
16
|
-
'compilation of some complex distributions
|
|
17
|
-
'red')
|
|
16
|
+
'compilation of some complex distributions '
|
|
17
|
+
'(Binomial, Negative-Binomial, Multinomial) will fail.', 'red')
|
|
18
18
|
traceback.print_exc()
|
|
19
19
|
tfp = None
|
|
20
20
|
|
|
@@ -32,6 +32,98 @@ from pyRDDLGym.core.debug.logger import Logger
|
|
|
32
32
|
from pyRDDLGym.core.simulator import RDDLSimulatorPrecompiled
|
|
33
33
|
|
|
34
34
|
|
|
35
|
+
# ===========================================================================
|
|
36
|
+
# EXACT RDDL TO JAX COMPILATION RULES
|
|
37
|
+
# ===========================================================================
|
|
38
|
+
|
|
39
|
+
def _function_unary_exact_named(op, name):
|
|
40
|
+
|
|
41
|
+
def _jax_wrapped_unary_fn_exact(x, param):
|
|
42
|
+
return op(x)
|
|
43
|
+
|
|
44
|
+
return _jax_wrapped_unary_fn_exact
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _function_unary_exact_named_gamma():
|
|
48
|
+
|
|
49
|
+
def _jax_wrapped_unary_gamma_exact(x, param):
|
|
50
|
+
return jnp.exp(scipy.special.gammaln(x))
|
|
51
|
+
|
|
52
|
+
return _jax_wrapped_unary_gamma_exact
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _function_binary_exact_named(op, name):
|
|
56
|
+
|
|
57
|
+
def _jax_wrapped_binary_fn_exact(x, y, param):
|
|
58
|
+
return op(x, y)
|
|
59
|
+
|
|
60
|
+
return _jax_wrapped_binary_fn_exact
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _function_binary_exact_named_implies():
|
|
64
|
+
|
|
65
|
+
def _jax_wrapped_binary_implies_exact(x, y, param):
|
|
66
|
+
return jnp.logical_or(jnp.logical_not(x), y)
|
|
67
|
+
|
|
68
|
+
return _jax_wrapped_binary_implies_exact
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _function_binary_exact_named_log():
|
|
72
|
+
|
|
73
|
+
def _jax_wrapped_binary_log_exact(x, y, param):
|
|
74
|
+
return jnp.log(x) / jnp.log(y)
|
|
75
|
+
|
|
76
|
+
return _jax_wrapped_binary_log_exact
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _function_aggregation_exact_named(op, name):
|
|
80
|
+
|
|
81
|
+
def _jax_wrapped_aggregation_fn_exact(x, axis, param):
|
|
82
|
+
return op(x, axis=axis)
|
|
83
|
+
|
|
84
|
+
return _jax_wrapped_aggregation_fn_exact
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _function_if_exact_named():
|
|
88
|
+
|
|
89
|
+
def _jax_wrapped_if_exact(c, a, b, param):
|
|
90
|
+
return jnp.where(c, a, b)
|
|
91
|
+
|
|
92
|
+
return _jax_wrapped_if_exact
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _function_switch_exact_named():
|
|
96
|
+
|
|
97
|
+
def _jax_wrapped_switch_exact(pred, cases, param):
|
|
98
|
+
pred = pred[jnp.newaxis, ...]
|
|
99
|
+
sample = jnp.take_along_axis(cases, pred, axis=0)
|
|
100
|
+
assert sample.shape[0] == 1
|
|
101
|
+
return sample[0, ...]
|
|
102
|
+
|
|
103
|
+
return _jax_wrapped_switch_exact
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _function_bernoulli_exact_named():
|
|
107
|
+
|
|
108
|
+
def _jax_wrapped_bernoulli_exact(key, prob, param):
|
|
109
|
+
return random.bernoulli(key, prob)
|
|
110
|
+
|
|
111
|
+
return _jax_wrapped_bernoulli_exact
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _function_discrete_exact_named():
|
|
115
|
+
|
|
116
|
+
def _jax_wrapped_discrete_exact(key, prob, param):
|
|
117
|
+
logits = jnp.log(prob)
|
|
118
|
+
sample = random.categorical(key=key, logits=logits, axis=-1)
|
|
119
|
+
out_of_bounds = jnp.logical_not(jnp.logical_and(
|
|
120
|
+
jnp.all(prob >= 0),
|
|
121
|
+
jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
|
|
122
|
+
return sample, out_of_bounds
|
|
123
|
+
|
|
124
|
+
return _jax_wrapped_discrete_exact
|
|
125
|
+
|
|
126
|
+
|
|
35
127
|
class JaxRDDLCompiler:
|
|
36
128
|
'''Compiles a RDDL AST representation into an equivalent JAX representation.
|
|
37
129
|
All operations are identical to their numpy equivalents.
|
|
@@ -39,10 +131,97 @@ class JaxRDDLCompiler:
|
|
|
39
131
|
|
|
40
132
|
MODEL_PARAM_TAG_SEPARATOR = '___'
|
|
41
133
|
|
|
134
|
+
# ===========================================================================
|
|
135
|
+
# EXACT RDDL TO JAX COMPILATION RULES BY DEFAULT
|
|
136
|
+
# ===========================================================================
|
|
137
|
+
|
|
138
|
+
EXACT_RDDL_TO_JAX_NEGATIVE = _function_unary_exact_named(jnp.negative, 'negative')
|
|
139
|
+
|
|
140
|
+
EXACT_RDDL_TO_JAX_ARITHMETIC = {
|
|
141
|
+
'+': _function_binary_exact_named(jnp.add, 'add'),
|
|
142
|
+
'-': _function_binary_exact_named(jnp.subtract, 'subtract'),
|
|
143
|
+
'*': _function_binary_exact_named(jnp.multiply, 'multiply'),
|
|
144
|
+
'/': _function_binary_exact_named(jnp.divide, 'divide')
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
EXACT_RDDL_TO_JAX_RELATIONAL = {
|
|
148
|
+
'>=': _function_binary_exact_named(jnp.greater_equal, 'greater_equal'),
|
|
149
|
+
'<=': _function_binary_exact_named(jnp.less_equal, 'less_equal'),
|
|
150
|
+
'<': _function_binary_exact_named(jnp.less, 'less'),
|
|
151
|
+
'>': _function_binary_exact_named(jnp.greater, 'greater'),
|
|
152
|
+
'==': _function_binary_exact_named(jnp.equal, 'equal'),
|
|
153
|
+
'~=': _function_binary_exact_named(jnp.not_equal, 'not_equal')
|
|
154
|
+
}
|
|
155
|
+
|
|
156
|
+
EXACT_RDDL_TO_JAX_LOGICAL = {
|
|
157
|
+
'^': _function_binary_exact_named(jnp.logical_and, 'and'),
|
|
158
|
+
'&': _function_binary_exact_named(jnp.logical_and, 'and'),
|
|
159
|
+
'|': _function_binary_exact_named(jnp.logical_or, 'or'),
|
|
160
|
+
'~': _function_binary_exact_named(jnp.logical_xor, 'xor'),
|
|
161
|
+
'=>': _function_binary_exact_named_implies(),
|
|
162
|
+
'<=>': _function_binary_exact_named(jnp.equal, 'iff')
|
|
163
|
+
}
|
|
164
|
+
|
|
165
|
+
EXACT_RDDL_TO_JAX_LOGICAL_NOT = _function_unary_exact_named(jnp.logical_not, 'not')
|
|
166
|
+
|
|
167
|
+
EXACT_RDDL_TO_JAX_AGGREGATION = {
|
|
168
|
+
'sum': _function_aggregation_exact_named(jnp.sum, 'sum'),
|
|
169
|
+
'avg': _function_aggregation_exact_named(jnp.mean, 'avg'),
|
|
170
|
+
'prod': _function_aggregation_exact_named(jnp.prod, 'prod'),
|
|
171
|
+
'minimum': _function_aggregation_exact_named(jnp.min, 'minimum'),
|
|
172
|
+
'maximum': _function_aggregation_exact_named(jnp.max, 'maximum'),
|
|
173
|
+
'forall': _function_aggregation_exact_named(jnp.all, 'forall'),
|
|
174
|
+
'exists': _function_aggregation_exact_named(jnp.any, 'exists'),
|
|
175
|
+
'argmin': _function_aggregation_exact_named(jnp.argmin, 'argmin'),
|
|
176
|
+
'argmax': _function_aggregation_exact_named(jnp.argmax, 'argmax')
|
|
177
|
+
}
|
|
178
|
+
|
|
179
|
+
EXACT_RDDL_TO_JAX_UNARY = {
|
|
180
|
+
'abs': _function_unary_exact_named(jnp.abs, 'abs'),
|
|
181
|
+
'sgn': _function_unary_exact_named(jnp.sign, 'sgn'),
|
|
182
|
+
'round': _function_unary_exact_named(jnp.round, 'round'),
|
|
183
|
+
'floor': _function_unary_exact_named(jnp.floor, 'floor'),
|
|
184
|
+
'ceil': _function_unary_exact_named(jnp.ceil, 'ceil'),
|
|
185
|
+
'cos': _function_unary_exact_named(jnp.cos, 'cos'),
|
|
186
|
+
'sin': _function_unary_exact_named(jnp.sin, 'sin'),
|
|
187
|
+
'tan': _function_unary_exact_named(jnp.tan, 'tan'),
|
|
188
|
+
'acos': _function_unary_exact_named(jnp.arccos, 'acos'),
|
|
189
|
+
'asin': _function_unary_exact_named(jnp.arcsin, 'asin'),
|
|
190
|
+
'atan': _function_unary_exact_named(jnp.arctan, 'atan'),
|
|
191
|
+
'cosh': _function_unary_exact_named(jnp.cosh, 'cosh'),
|
|
192
|
+
'sinh': _function_unary_exact_named(jnp.sinh, 'sinh'),
|
|
193
|
+
'tanh': _function_unary_exact_named(jnp.tanh, 'tanh'),
|
|
194
|
+
'exp': _function_unary_exact_named(jnp.exp, 'exp'),
|
|
195
|
+
'ln': _function_unary_exact_named(jnp.log, 'ln'),
|
|
196
|
+
'sqrt': _function_unary_exact_named(jnp.sqrt, 'sqrt'),
|
|
197
|
+
'lngamma': _function_unary_exact_named(scipy.special.gammaln, 'lngamma'),
|
|
198
|
+
'gamma': _function_unary_exact_named_gamma()
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
EXACT_RDDL_TO_JAX_BINARY = {
|
|
202
|
+
'div': _function_binary_exact_named(jnp.floor_divide, 'div'),
|
|
203
|
+
'mod': _function_binary_exact_named(jnp.mod, 'mod'),
|
|
204
|
+
'fmod': _function_binary_exact_named(jnp.mod, 'fmod'),
|
|
205
|
+
'min': _function_binary_exact_named(jnp.minimum, 'min'),
|
|
206
|
+
'max': _function_binary_exact_named(jnp.maximum, 'max'),
|
|
207
|
+
'pow': _function_binary_exact_named(jnp.power, 'pow'),
|
|
208
|
+
'log': _function_binary_exact_named_log(),
|
|
209
|
+
'hypot': _function_binary_exact_named(jnp.hypot, 'hypot'),
|
|
210
|
+
}
|
|
211
|
+
|
|
212
|
+
EXACT_RDDL_TO_JAX_IF = _function_if_exact_named()
|
|
213
|
+
|
|
214
|
+
EXACT_RDDL_TO_JAX_SWITCH = _function_switch_exact_named()
|
|
215
|
+
|
|
216
|
+
EXACT_RDDL_TO_JAX_BERNOULLI = _function_bernoulli_exact_named()
|
|
217
|
+
|
|
218
|
+
EXACT_RDDL_TO_JAX_DISCRETE = _function_discrete_exact_named()
|
|
219
|
+
|
|
42
220
|
def __init__(self, rddl: RDDLLiftedModel,
|
|
43
221
|
allow_synchronous_state: bool=True,
|
|
44
|
-
logger: Logger=None,
|
|
45
|
-
use64bit: bool=False
|
|
222
|
+
logger: Optional[Logger]=None,
|
|
223
|
+
use64bit: bool=False,
|
|
224
|
+
compile_non_fluent_exact: bool=True) -> None:
|
|
46
225
|
'''Creates a new RDDL to Jax compiler.
|
|
47
226
|
|
|
48
227
|
:param rddl: the RDDL model to compile into Jax
|
|
@@ -50,11 +229,14 @@ class JaxRDDLCompiler:
|
|
|
50
229
|
on each other
|
|
51
230
|
:param logger: to log information about compilation to file
|
|
52
231
|
:param use64bit: whether to use 64 bit arithmetic
|
|
232
|
+
:param compile_non_fluent_exact: whether non-fluent expressions
|
|
233
|
+
are always compiled using exact JAX expressions.
|
|
53
234
|
'''
|
|
54
235
|
self.rddl = rddl
|
|
55
236
|
self.logger = logger
|
|
56
237
|
# jax.config.update('jax_log_compiles', True) # for testing ONLY
|
|
57
238
|
|
|
239
|
+
self.use64bit = use64bit
|
|
58
240
|
if use64bit:
|
|
59
241
|
self.INT = jnp.int64
|
|
60
242
|
self.REAL = jnp.float64
|
|
@@ -62,6 +244,7 @@ class JaxRDDLCompiler:
|
|
|
62
244
|
else:
|
|
63
245
|
self.INT = jnp.int32
|
|
64
246
|
self.REAL = jnp.float32
|
|
247
|
+
jax.config.update('jax_enable_x64', False)
|
|
65
248
|
self.ONE = jnp.asarray(1, dtype=self.INT)
|
|
66
249
|
self.JAX_TYPES = {
|
|
67
250
|
'int': self.INT,
|
|
@@ -70,17 +253,16 @@ class JaxRDDLCompiler:
|
|
|
70
253
|
}
|
|
71
254
|
|
|
72
255
|
# compile initial values
|
|
73
|
-
|
|
74
|
-
self.logger.clear()
|
|
75
|
-
initializer = RDDLValueInitializer(rddl, logger=self.logger)
|
|
256
|
+
initializer = RDDLValueInitializer(rddl)
|
|
76
257
|
self.init_values = initializer.initialize()
|
|
77
258
|
|
|
78
259
|
# compute dependency graph for CPFs and sort them by evaluation order
|
|
79
|
-
sorter = RDDLLevelAnalysis(
|
|
260
|
+
sorter = RDDLLevelAnalysis(
|
|
261
|
+
rddl, allow_synchronous_state=allow_synchronous_state)
|
|
80
262
|
self.levels = sorter.compute_levels()
|
|
81
263
|
|
|
82
264
|
# trace expressions to cache information to be used later
|
|
83
|
-
tracer = RDDLObjectsTracer(rddl,
|
|
265
|
+
tracer = RDDLObjectsTracer(rddl, cpf_levels=self.levels)
|
|
84
266
|
self.traced = tracer.trace()
|
|
85
267
|
|
|
86
268
|
# extract the box constraints on actions
|
|
@@ -92,92 +274,42 @@ class JaxRDDLCompiler:
|
|
|
92
274
|
constraints = RDDLConstraints(simulator, vectorized=True)
|
|
93
275
|
self.constraints = constraints
|
|
94
276
|
|
|
95
|
-
# basic operations
|
|
96
|
-
self.
|
|
97
|
-
self.
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
self.RELATIONAL_OPS = {
|
|
104
|
-
'>=': lambda x, y, param: jnp.greater_equal(x, y),
|
|
105
|
-
'<=': lambda x, y, param: jnp.less_equal(x, y),
|
|
106
|
-
'<': lambda x, y, param: jnp.less(x, y),
|
|
107
|
-
'>': lambda x, y, param: jnp.greater(x, y),
|
|
108
|
-
'==': lambda x, y, param: jnp.equal(x, y),
|
|
109
|
-
'~=': lambda x, y, param: jnp.not_equal(x, y)
|
|
110
|
-
}
|
|
111
|
-
self.LOGICAL_NOT = lambda x, param: jnp.logical_not(x)
|
|
112
|
-
self.LOGICAL_OPS = {
|
|
113
|
-
'^': lambda x, y, param: jnp.logical_and(x, y),
|
|
114
|
-
'&': lambda x, y, param: jnp.logical_and(x, y),
|
|
115
|
-
'|': lambda x, y, param: jnp.logical_or(x, y),
|
|
116
|
-
'~': lambda x, y, param: jnp.logical_xor(x, y),
|
|
117
|
-
'=>': lambda x, y, param: jnp.logical_or(jnp.logical_not(x), y),
|
|
118
|
-
'<=>': lambda x, y, param: jnp.equal(x, y)
|
|
119
|
-
}
|
|
120
|
-
self.AGGREGATION_OPS = {
|
|
121
|
-
'sum': lambda x, axis, param: jnp.sum(x, axis=axis),
|
|
122
|
-
'avg': lambda x, axis, param: jnp.mean(x, axis=axis),
|
|
123
|
-
'prod': lambda x, axis, param: jnp.prod(x, axis=axis),
|
|
124
|
-
'minimum': lambda x, axis, param: jnp.min(x, axis=axis),
|
|
125
|
-
'maximum': lambda x, axis, param: jnp.max(x, axis=axis),
|
|
126
|
-
'forall': lambda x, axis, param: jnp.all(x, axis=axis),
|
|
127
|
-
'exists': lambda x, axis, param: jnp.any(x, axis=axis),
|
|
128
|
-
'argmin': lambda x, axis, param: jnp.argmin(x, axis=axis),
|
|
129
|
-
'argmax': lambda x, axis, param: jnp.argmax(x, axis=axis)
|
|
130
|
-
}
|
|
277
|
+
# basic operations - these can be override in subclasses
|
|
278
|
+
self.compile_non_fluent_exact = compile_non_fluent_exact
|
|
279
|
+
self.NEGATIVE = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_NEGATIVE
|
|
280
|
+
self.ARITHMETIC_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_ARITHMETIC.copy()
|
|
281
|
+
self.RELATIONAL_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_RELATIONAL.copy()
|
|
282
|
+
self.LOGICAL_NOT = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL_NOT
|
|
283
|
+
self.LOGICAL_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL.copy()
|
|
284
|
+
self.AGGREGATION_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_AGGREGATION.copy()
|
|
131
285
|
self.AGGREGATION_BOOL = {'forall', 'exists'}
|
|
132
|
-
self.KNOWN_UNARY =
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
'sin': lambda x, param: jnp.sin(x),
|
|
140
|
-
'tan': lambda x, param: jnp.tan(x),
|
|
141
|
-
'acos': lambda x, param: jnp.arccos(x),
|
|
142
|
-
'asin': lambda x, param: jnp.arcsin(x),
|
|
143
|
-
'atan': lambda x, param: jnp.arctan(x),
|
|
144
|
-
'cosh': lambda x, param: jnp.cosh(x),
|
|
145
|
-
'sinh': lambda x, param: jnp.sinh(x),
|
|
146
|
-
'tanh': lambda x, param: jnp.tanh(x),
|
|
147
|
-
'exp': lambda x, param: jnp.exp(x),
|
|
148
|
-
'ln': lambda x, param: jnp.log(x),
|
|
149
|
-
'sqrt': lambda x, param: jnp.sqrt(x),
|
|
150
|
-
'lngamma': lambda x, param: scipy.special.gammaln(x),
|
|
151
|
-
'gamma': lambda x, param: jnp.exp(scipy.special.gammaln(x))
|
|
152
|
-
}
|
|
153
|
-
self.KNOWN_BINARY = {
|
|
154
|
-
'div': lambda x, y, param: jnp.floor_divide(x, y),
|
|
155
|
-
'mod': lambda x, y, param: jnp.mod(x, y),
|
|
156
|
-
'fmod': lambda x, y, param: jnp.mod(x, y),
|
|
157
|
-
'min': lambda x, y, param: jnp.minimum(x, y),
|
|
158
|
-
'max': lambda x, y, param: jnp.maximum(x, y),
|
|
159
|
-
'pow': lambda x, y, param: jnp.power(x, y),
|
|
160
|
-
'log': lambda x, y, param: jnp.log(x) / jnp.log(y),
|
|
161
|
-
'hypot': lambda x, y, param: jnp.hypot(x, y)
|
|
162
|
-
}
|
|
163
|
-
|
|
286
|
+
self.KNOWN_UNARY = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_UNARY.copy()
|
|
287
|
+
self.KNOWN_BINARY = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BINARY.copy()
|
|
288
|
+
self.IF_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_IF
|
|
289
|
+
self.SWITCH_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
|
|
290
|
+
self.BERNOULLI_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BERNOULLI
|
|
291
|
+
self.DISCRETE_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
|
|
292
|
+
|
|
164
293
|
# ===========================================================================
|
|
165
294
|
# main compilation subroutines
|
|
166
295
|
# ===========================================================================
|
|
167
296
|
|
|
168
|
-
def compile(self, log_jax_expr: bool=False) -> None:
|
|
297
|
+
def compile(self, log_jax_expr: bool=False, heading: str='') -> None:
|
|
169
298
|
'''Compiles the current RDDL into Jax expressions.
|
|
170
299
|
|
|
171
300
|
:param log_jax_expr: whether to pretty-print the compiled Jax functions
|
|
172
301
|
to the log file
|
|
302
|
+
:param heading: the heading to print before compilation information
|
|
173
303
|
'''
|
|
174
|
-
info = {}
|
|
304
|
+
info = ({}, [])
|
|
175
305
|
self.invariants = self._compile_constraints(self.rddl.invariants, info)
|
|
176
306
|
self.preconditions = self._compile_constraints(self.rddl.preconditions, info)
|
|
177
307
|
self.terminations = self._compile_constraints(self.rddl.terminations, info)
|
|
178
308
|
self.cpfs = self._compile_cpfs(info)
|
|
179
309
|
self.reward = self._compile_reward(info)
|
|
180
|
-
self.model_params =
|
|
310
|
+
self.model_params = {key: value
|
|
311
|
+
for (key, (value, *_)) in info[0].items()}
|
|
312
|
+
self.relaxations = info[1]
|
|
181
313
|
|
|
182
314
|
if log_jax_expr and self.logger is not None:
|
|
183
315
|
printed = self.print_jax()
|
|
@@ -189,6 +321,7 @@ class JaxRDDLCompiler:
|
|
|
189
321
|
printed_terminals = '\n\n'.join(v for v in printed['terminations'])
|
|
190
322
|
printed_params = '\n'.join(f'{k}: {v}' for (k, v) in info.items())
|
|
191
323
|
message = (
|
|
324
|
+
f'[info] {heading}\n'
|
|
192
325
|
f'[info] compiled JAX CPFs:\n\n'
|
|
193
326
|
f'{printed_cpfs}\n\n'
|
|
194
327
|
f'[info] compiled JAX reward:\n\n'
|
|
@@ -281,17 +414,21 @@ class JaxRDDLCompiler:
|
|
|
281
414
|
return jax_inequalities, jax_equalities
|
|
282
415
|
|
|
283
416
|
def compile_transition(self, check_constraints: bool=False,
|
|
284
|
-
constraint_func: bool=False):
|
|
417
|
+
constraint_func: bool=False) -> Callable:
|
|
285
418
|
'''Compiles the current RDDL model into a JAX transition function that
|
|
286
419
|
samples the next state.
|
|
287
420
|
|
|
288
|
-
The
|
|
289
|
-
model_params), where:
|
|
421
|
+
The arguments of the returned function is:
|
|
290
422
|
- key is the PRNG key
|
|
291
423
|
- actions is the dict of action tensors
|
|
292
424
|
- subs is the dict of current pvar value tensors
|
|
293
425
|
- model_params is a dict of parameters for the relaxed model.
|
|
294
|
-
|
|
426
|
+
|
|
427
|
+
The returned value of the function is:
|
|
428
|
+
- subs is the returned next epoch fluent values
|
|
429
|
+
- log includes all the auxiliary information about constraints
|
|
430
|
+
satisfied, errors, etc.
|
|
431
|
+
|
|
295
432
|
constraint_func provides the option to compile nonlinear constraints:
|
|
296
433
|
|
|
297
434
|
1. f(s, a) ?? g(s, a)
|
|
@@ -361,6 +498,10 @@ class JaxRDDLCompiler:
|
|
|
361
498
|
reward, key, err = reward_fn(subs, model_params, key)
|
|
362
499
|
errors |= err
|
|
363
500
|
|
|
501
|
+
# calculate fluent values
|
|
502
|
+
fluents = {name: values for (name, values) in subs.items()
|
|
503
|
+
if name not in rddl.non_fluents}
|
|
504
|
+
|
|
364
505
|
# set the next state to the current state
|
|
365
506
|
for (state, next_state) in rddl.next_state.items():
|
|
366
507
|
subs[state] = subs[next_state]
|
|
@@ -383,8 +524,7 @@ class JaxRDDLCompiler:
|
|
|
383
524
|
|
|
384
525
|
# prepare the return value
|
|
385
526
|
log = {
|
|
386
|
-
'
|
|
387
|
-
'action': actions,
|
|
527
|
+
'fluents': fluents,
|
|
388
528
|
'reward': reward,
|
|
389
529
|
'error': errors,
|
|
390
530
|
'precondition': precond_check,
|
|
@@ -395,7 +535,7 @@ class JaxRDDLCompiler:
|
|
|
395
535
|
log['inequalities'] = inequalities
|
|
396
536
|
log['equalities'] = equalities
|
|
397
537
|
|
|
398
|
-
return log
|
|
538
|
+
return subs, log
|
|
399
539
|
|
|
400
540
|
return _jax_wrapped_single_step
|
|
401
541
|
|
|
@@ -403,18 +543,28 @@ class JaxRDDLCompiler:
|
|
|
403
543
|
n_steps: int,
|
|
404
544
|
n_batch: int,
|
|
405
545
|
check_constraints: bool=False,
|
|
406
|
-
constraint_func: bool=False):
|
|
546
|
+
constraint_func: bool=False) -> Callable:
|
|
407
547
|
'''Compiles the current RDDL model into a JAX transition function that
|
|
408
548
|
samples trajectories with a fixed horizon from a policy.
|
|
409
549
|
|
|
410
|
-
The
|
|
411
|
-
|
|
550
|
+
The arguments of the returned function is:
|
|
551
|
+
- key is the PRNG key (used by a stochastic policy)
|
|
552
|
+
- policy_params is a pytree of trainable policy weights
|
|
553
|
+
- hyperparams is a pytree of (optional) fixed policy hyper-parameters
|
|
554
|
+
- subs is the dictionary of current fluent tensor values
|
|
555
|
+
- model_params is a dict of model hyperparameters.
|
|
556
|
+
|
|
557
|
+
The returned value of the returned function is:
|
|
558
|
+
- log is the dictionary of all trajectory information, including
|
|
559
|
+
constraints that were satisfied, errors, etc.
|
|
560
|
+
|
|
561
|
+
The arguments of the policy function is:
|
|
412
562
|
- key is the PRNG key (used by a stochastic policy)
|
|
413
563
|
- params is a pytree of trainable policy weights
|
|
414
564
|
- hyperparams is a pytree of (optional) fixed policy hyper-parameters
|
|
415
565
|
- step is the time index of the decision in the current rollout
|
|
416
566
|
- states is a dict of tensors for the current observation.
|
|
417
|
-
|
|
567
|
+
|
|
418
568
|
:param policy: a Jax compiled function for the policy as described above
|
|
419
569
|
decision epoch, state dict, and an RNG key and returns an action dict
|
|
420
570
|
:param n_steps: the rollout horizon
|
|
@@ -428,27 +578,32 @@ class JaxRDDLCompiler:
|
|
|
428
578
|
rddl = self.rddl
|
|
429
579
|
jax_step_fn = self.compile_transition(check_constraints, constraint_func)
|
|
430
580
|
|
|
581
|
+
# for POMDP only observ-fluents are assumed visible to the policy
|
|
582
|
+
if rddl.observ_fluents:
|
|
583
|
+
observed_vars = rddl.observ_fluents
|
|
584
|
+
else:
|
|
585
|
+
observed_vars = rddl.state_fluents
|
|
586
|
+
|
|
431
587
|
# evaluate the step from the policy
|
|
432
588
|
def _jax_wrapped_single_step_policy(key, policy_params, hyperparams,
|
|
433
589
|
step, subs, model_params):
|
|
434
590
|
states = {var: values
|
|
435
591
|
for (var, values) in subs.items()
|
|
436
|
-
if
|
|
592
|
+
if var in observed_vars}
|
|
437
593
|
actions = policy(key, policy_params, hyperparams, step, states)
|
|
438
594
|
key, subkey = random.split(key)
|
|
439
|
-
log = jax_step_fn(subkey, actions, subs, model_params)
|
|
440
|
-
return log
|
|
595
|
+
subs, log = jax_step_fn(subkey, actions, subs, model_params)
|
|
596
|
+
return subs, log
|
|
441
597
|
|
|
442
598
|
# do a batched step update from the policy
|
|
443
599
|
def _jax_wrapped_batched_step_policy(carry, step):
|
|
444
600
|
key, policy_params, hyperparams, subs, model_params = carry
|
|
445
601
|
key, *subkeys = random.split(key, num=1 + n_batch)
|
|
446
602
|
keys = jnp.asarray(subkeys)
|
|
447
|
-
log = jax.vmap(
|
|
603
|
+
subs, log = jax.vmap(
|
|
448
604
|
_jax_wrapped_single_step_policy,
|
|
449
605
|
in_axes=(0, None, None, None, 0, None)
|
|
450
606
|
)(keys, policy_params, hyperparams, step, subs, model_params)
|
|
451
|
-
subs = log['pvar']
|
|
452
607
|
carry = (key, policy_params, hyperparams, subs, model_params)
|
|
453
608
|
return carry, log
|
|
454
609
|
|
|
@@ -467,7 +622,7 @@ class JaxRDDLCompiler:
|
|
|
467
622
|
# error checks
|
|
468
623
|
# ===========================================================================
|
|
469
624
|
|
|
470
|
-
def print_jax(self) -> Dict[str,
|
|
625
|
+
def print_jax(self) -> Dict[str, Any]:
|
|
471
626
|
'''Returns a dictionary containing the string representations of all
|
|
472
627
|
Jax compiled expressions from the RDDL file.
|
|
473
628
|
'''
|
|
@@ -564,7 +719,7 @@ class JaxRDDLCompiler:
|
|
|
564
719
|
}
|
|
565
720
|
|
|
566
721
|
@staticmethod
|
|
567
|
-
def get_error_codes(error):
|
|
722
|
+
def get_error_codes(error: int) -> List[int]:
|
|
568
723
|
'''Given a compacted integer error flag from the execution of Jax, and
|
|
569
724
|
decomposes it into individual error codes.
|
|
570
725
|
'''
|
|
@@ -573,7 +728,7 @@ class JaxRDDLCompiler:
|
|
|
573
728
|
return errors
|
|
574
729
|
|
|
575
730
|
@staticmethod
|
|
576
|
-
def get_error_messages(error):
|
|
731
|
+
def get_error_messages(error: int) -> List[str]:
|
|
577
732
|
'''Given a compacted integer error flag from the execution of Jax, and
|
|
578
733
|
decomposes it into error strings.
|
|
579
734
|
'''
|
|
@@ -586,28 +741,40 @@ class JaxRDDLCompiler:
|
|
|
586
741
|
# ===========================================================================
|
|
587
742
|
|
|
588
743
|
def _unwrap(self, op, expr_id, info):
|
|
589
|
-
sep = JaxRDDLCompiler.MODEL_PARAM_TAG_SEPARATOR
|
|
590
744
|
jax_op, name = op, None
|
|
745
|
+
model_params, relaxed_list = info
|
|
591
746
|
if isinstance(op, tuple):
|
|
592
747
|
jax_op, param = op
|
|
593
748
|
if param is not None:
|
|
594
749
|
tags, values = param
|
|
750
|
+
sep = JaxRDDLCompiler.MODEL_PARAM_TAG_SEPARATOR
|
|
595
751
|
if isinstance(tags, tuple):
|
|
596
752
|
name = sep.join(tags)
|
|
597
753
|
else:
|
|
598
754
|
name = str(tags)
|
|
599
755
|
name = f'{name}{sep}{expr_id}'
|
|
600
|
-
if name in
|
|
601
|
-
raise
|
|
602
|
-
|
|
756
|
+
if name in model_params:
|
|
757
|
+
raise RuntimeError(
|
|
758
|
+
f'Internal error: model parameter {name} is already defined.')
|
|
759
|
+
model_params[name] = (values, tags, expr_id, jax_op.__name__)
|
|
760
|
+
relaxed_list.append((param, expr_id, jax_op.__name__))
|
|
603
761
|
return jax_op, name
|
|
604
762
|
|
|
605
|
-
def
|
|
606
|
-
'''Returns a
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
763
|
+
def summarize_model_relaxations(self) -> str:
|
|
764
|
+
'''Returns a string of information about model relaxations in the
|
|
765
|
+
compiled model.'''
|
|
766
|
+
occurence_by_type = {}
|
|
767
|
+
for (_, expr_id, jax_op) in self.relaxations:
|
|
768
|
+
etype = self.traced.lookup(expr_id).etype
|
|
769
|
+
source = f'{etype[1]} ({etype[0]})'
|
|
770
|
+
sub = f'{source:<30} --> {jax_op}'
|
|
771
|
+
occurence_by_type[sub] = occurence_by_type.get(sub, 0) + 1
|
|
772
|
+
col = "{:<80} {:<10}\n"
|
|
773
|
+
table = col.format('Substitution', 'Count')
|
|
774
|
+
for (sub, occurs) in occurence_by_type.items():
|
|
775
|
+
table += col.format(sub, occurs)
|
|
776
|
+
return table
|
|
777
|
+
|
|
611
778
|
# ===========================================================================
|
|
612
779
|
# expression compilation
|
|
613
780
|
# ===========================================================================
|
|
@@ -640,7 +807,8 @@ class JaxRDDLCompiler:
|
|
|
640
807
|
raise RDDLNotImplementedError(
|
|
641
808
|
f'Internal error: expression type {expr} is not supported.\n' +
|
|
642
809
|
print_stack_trace(expr))
|
|
643
|
-
|
|
810
|
+
|
|
811
|
+
# force type cast of tensor as required by caller
|
|
644
812
|
if dtype is not None:
|
|
645
813
|
jax_expr = self._jax_cast(jax_expr, dtype)
|
|
646
814
|
|
|
@@ -660,6 +828,17 @@ class JaxRDDLCompiler:
|
|
|
660
828
|
|
|
661
829
|
return _jax_wrapped_cast
|
|
662
830
|
|
|
831
|
+
def _fix_dtype(self, value):
|
|
832
|
+
dtype = jnp.atleast_1d(value).dtype
|
|
833
|
+
if jnp.issubdtype(dtype, jnp.integer):
|
|
834
|
+
return self.INT
|
|
835
|
+
elif jnp.issubdtype(dtype, jnp.floating):
|
|
836
|
+
return self.REAL
|
|
837
|
+
elif jnp.issubdtype(dtype, jnp.bool_) or jnp.issubdtype(dtype, bool):
|
|
838
|
+
return bool
|
|
839
|
+
else:
|
|
840
|
+
raise TypeError(f'Invalid type {dtype} of {value}.')
|
|
841
|
+
|
|
663
842
|
# ===========================================================================
|
|
664
843
|
# leaves
|
|
665
844
|
# ===========================================================================
|
|
@@ -669,7 +848,7 @@ class JaxRDDLCompiler:
|
|
|
669
848
|
cached_value = self.traced.cached_sim_info(expr)
|
|
670
849
|
|
|
671
850
|
def _jax_wrapped_constant(x, params, key):
|
|
672
|
-
sample = jnp.asarray(cached_value)
|
|
851
|
+
sample = jnp.asarray(cached_value, dtype=self._fix_dtype(cached_value))
|
|
673
852
|
return sample, key, NORMAL
|
|
674
853
|
|
|
675
854
|
return _jax_wrapped_constant
|
|
@@ -693,7 +872,7 @@ class JaxRDDLCompiler:
|
|
|
693
872
|
cached_value = cached_info
|
|
694
873
|
|
|
695
874
|
def _jax_wrapped_object(x, params, key):
|
|
696
|
-
sample = jnp.asarray(cached_value)
|
|
875
|
+
sample = jnp.asarray(cached_value, dtype=self._fix_dtype(cached_value))
|
|
697
876
|
return sample, key, NORMAL
|
|
698
877
|
|
|
699
878
|
return _jax_wrapped_object
|
|
@@ -702,7 +881,8 @@ class JaxRDDLCompiler:
|
|
|
702
881
|
elif cached_info is None:
|
|
703
882
|
|
|
704
883
|
def _jax_wrapped_pvar_scalar(x, params, key):
|
|
705
|
-
|
|
884
|
+
value = x[var]
|
|
885
|
+
sample = jnp.asarray(value, dtype=self._fix_dtype(value))
|
|
706
886
|
return sample, key, NORMAL
|
|
707
887
|
|
|
708
888
|
return _jax_wrapped_pvar_scalar
|
|
@@ -721,7 +901,8 @@ class JaxRDDLCompiler:
|
|
|
721
901
|
|
|
722
902
|
def _jax_wrapped_pvar_tensor_nested(x, params, key):
|
|
723
903
|
error = NORMAL
|
|
724
|
-
|
|
904
|
+
value = x[var]
|
|
905
|
+
sample = jnp.asarray(value, dtype=self._fix_dtype(value))
|
|
725
906
|
new_slices = [None] * len(jax_nested_expr)
|
|
726
907
|
for (i, jax_expr) in enumerate(jax_nested_expr):
|
|
727
908
|
new_slices[i], key, err = jax_expr(x, params, key)
|
|
@@ -736,7 +917,8 @@ class JaxRDDLCompiler:
|
|
|
736
917
|
else:
|
|
737
918
|
|
|
738
919
|
def _jax_wrapped_pvar_tensor_non_nested(x, params, key):
|
|
739
|
-
|
|
920
|
+
value = x[var]
|
|
921
|
+
sample = jnp.asarray(value, dtype=self._fix_dtype(value))
|
|
740
922
|
if slices:
|
|
741
923
|
sample = sample[slices]
|
|
742
924
|
if axis:
|
|
@@ -795,16 +977,23 @@ class JaxRDDLCompiler:
|
|
|
795
977
|
|
|
796
978
|
def _jax_arithmetic(self, expr, info):
|
|
797
979
|
_, op = expr.etype
|
|
798
|
-
|
|
980
|
+
|
|
981
|
+
# if expression is non-fluent, always use the exact operation
|
|
982
|
+
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
|
|
983
|
+
valid_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_ARITHMETIC
|
|
984
|
+
negative_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_NEGATIVE
|
|
985
|
+
else:
|
|
986
|
+
valid_ops = self.ARITHMETIC_OPS
|
|
987
|
+
negative_op = self.NEGATIVE
|
|
799
988
|
JaxRDDLCompiler._check_valid_op(expr, valid_ops)
|
|
800
|
-
|
|
989
|
+
|
|
990
|
+
# recursively compile arguments
|
|
801
991
|
args = expr.args
|
|
802
992
|
n = len(args)
|
|
803
|
-
|
|
804
993
|
if n == 1 and op == '-':
|
|
805
994
|
arg, = args
|
|
806
995
|
jax_expr = self._jax(arg, info)
|
|
807
|
-
jax_op, jax_param = self._unwrap(
|
|
996
|
+
jax_op, jax_param = self._unwrap(negative_op, expr.id, info)
|
|
808
997
|
return self._jax_unary(jax_expr, jax_op, jax_param, at_least_int=True)
|
|
809
998
|
|
|
810
999
|
elif n == 2:
|
|
@@ -819,29 +1008,42 @@ class JaxRDDLCompiler:
|
|
|
819
1008
|
|
|
820
1009
|
def _jax_relational(self, expr, info):
|
|
821
1010
|
_, op = expr.etype
|
|
822
|
-
|
|
1011
|
+
|
|
1012
|
+
# if expression is non-fluent, always use the exact operation
|
|
1013
|
+
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
|
|
1014
|
+
valid_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_RELATIONAL
|
|
1015
|
+
else:
|
|
1016
|
+
valid_ops = self.RELATIONAL_OPS
|
|
823
1017
|
JaxRDDLCompiler._check_valid_op(expr, valid_ops)
|
|
824
|
-
|
|
1018
|
+
jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
|
|
825
1019
|
|
|
1020
|
+
# recursively compile arguments
|
|
1021
|
+
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
826
1022
|
lhs, rhs = expr.args
|
|
827
1023
|
jax_lhs = self._jax(lhs, info)
|
|
828
1024
|
jax_rhs = self._jax(rhs, info)
|
|
829
|
-
jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
|
|
830
1025
|
return self._jax_binary(
|
|
831
1026
|
jax_lhs, jax_rhs, jax_op, jax_param, at_least_int=True)
|
|
832
1027
|
|
|
833
1028
|
def _jax_logical(self, expr, info):
|
|
834
1029
|
_, op = expr.etype
|
|
835
|
-
|
|
1030
|
+
|
|
1031
|
+
# if expression is non-fluent, always use the exact operation
|
|
1032
|
+
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
|
|
1033
|
+
valid_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL
|
|
1034
|
+
logical_not_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL_NOT
|
|
1035
|
+
else:
|
|
1036
|
+
valid_ops = self.LOGICAL_OPS
|
|
1037
|
+
logical_not_op = self.LOGICAL_NOT
|
|
836
1038
|
JaxRDDLCompiler._check_valid_op(expr, valid_ops)
|
|
837
1039
|
|
|
1040
|
+
# recursively compile arguments
|
|
838
1041
|
args = expr.args
|
|
839
|
-
n = len(args)
|
|
840
|
-
|
|
1042
|
+
n = len(args)
|
|
841
1043
|
if n == 1 and op == '~':
|
|
842
1044
|
arg, = args
|
|
843
1045
|
jax_expr = self._jax(arg, info)
|
|
844
|
-
jax_op, jax_param = self._unwrap(
|
|
1046
|
+
jax_op, jax_param = self._unwrap(logical_not_op, expr.id, info)
|
|
845
1047
|
return self._jax_unary(jax_expr, jax_op, jax_param, check_dtype=bool)
|
|
846
1048
|
|
|
847
1049
|
elif n == 2:
|
|
@@ -856,17 +1058,21 @@ class JaxRDDLCompiler:
|
|
|
856
1058
|
|
|
857
1059
|
def _jax_aggregation(self, expr, info):
|
|
858
1060
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
|
|
859
|
-
|
|
860
1061
|
_, op = expr.etype
|
|
861
|
-
|
|
1062
|
+
|
|
1063
|
+
# if expression is non-fluent, always use the exact operation
|
|
1064
|
+
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
|
|
1065
|
+
valid_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_AGGREGATION
|
|
1066
|
+
else:
|
|
1067
|
+
valid_ops = self.AGGREGATION_OPS
|
|
862
1068
|
JaxRDDLCompiler._check_valid_op(expr, valid_ops)
|
|
863
|
-
|
|
1069
|
+
jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
|
|
864
1070
|
|
|
1071
|
+
# recursively compile arguments
|
|
1072
|
+
is_floating = op not in self.AGGREGATION_BOOL
|
|
865
1073
|
* _, arg = expr.args
|
|
866
|
-
_, axes = self.traced.cached_sim_info(expr)
|
|
867
|
-
|
|
1074
|
+
_, axes = self.traced.cached_sim_info(expr)
|
|
868
1075
|
jax_expr = self._jax(arg, info)
|
|
869
|
-
jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
|
|
870
1076
|
|
|
871
1077
|
def _jax_wrapped_aggregation(x, params, key):
|
|
872
1078
|
sample, key, err = jax_expr(x, params, key)
|
|
@@ -884,21 +1090,28 @@ class JaxRDDLCompiler:
|
|
|
884
1090
|
def _jax_functional(self, expr, info):
|
|
885
1091
|
_, op = expr.etype
|
|
886
1092
|
|
|
887
|
-
#
|
|
888
|
-
if
|
|
1093
|
+
# if expression is non-fluent, always use the exact operation
|
|
1094
|
+
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
|
|
1095
|
+
unary_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_UNARY
|
|
1096
|
+
binary_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BINARY
|
|
1097
|
+
else:
|
|
1098
|
+
unary_ops = self.KNOWN_UNARY
|
|
1099
|
+
binary_ops = self.KNOWN_BINARY
|
|
1100
|
+
|
|
1101
|
+
# recursively compile arguments
|
|
1102
|
+
if op in unary_ops:
|
|
889
1103
|
JaxRDDLCompiler._check_num_args(expr, 1)
|
|
890
1104
|
arg, = expr.args
|
|
891
1105
|
jax_expr = self._jax(arg, info)
|
|
892
|
-
jax_op, jax_param = self._unwrap(
|
|
1106
|
+
jax_op, jax_param = self._unwrap(unary_ops[op], expr.id, info)
|
|
893
1107
|
return self._jax_unary(jax_expr, jax_op, jax_param, at_least_int=True)
|
|
894
1108
|
|
|
895
|
-
|
|
896
|
-
elif op in self.KNOWN_BINARY:
|
|
1109
|
+
elif op in binary_ops:
|
|
897
1110
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
898
1111
|
lhs, rhs = expr.args
|
|
899
1112
|
jax_lhs = self._jax(lhs, info)
|
|
900
1113
|
jax_rhs = self._jax(rhs, info)
|
|
901
|
-
jax_op, jax_param = self._unwrap(
|
|
1114
|
+
jax_op, jax_param = self._unwrap(binary_ops[op], expr.id, info)
|
|
902
1115
|
return self._jax_binary(
|
|
903
1116
|
jax_lhs, jax_rhs, jax_op, jax_param, at_least_int=True)
|
|
904
1117
|
|
|
@@ -921,19 +1134,19 @@ class JaxRDDLCompiler:
|
|
|
921
1134
|
f'Control operator {op} is not supported.\n' +
|
|
922
1135
|
print_stack_trace(expr))
|
|
923
1136
|
|
|
924
|
-
def _jax_if_helper(self):
|
|
925
|
-
|
|
926
|
-
def _jax_wrapped_if_calc_exact(c, a, b, param):
|
|
927
|
-
return jnp.where(c, a, b)
|
|
928
|
-
|
|
929
|
-
return _jax_wrapped_if_calc_exact
|
|
930
|
-
|
|
931
1137
|
def _jax_if(self, expr, info):
|
|
932
1138
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_CAST']
|
|
933
1139
|
JaxRDDLCompiler._check_num_args(expr, 3)
|
|
934
|
-
|
|
1140
|
+
pred, if_true, if_false = expr.args
|
|
1141
|
+
|
|
1142
|
+
# if predicate is non-fluent, always use the exact operation
|
|
1143
|
+
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(pred):
|
|
1144
|
+
if_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_IF
|
|
1145
|
+
else:
|
|
1146
|
+
if_op = self.IF_HELPER
|
|
1147
|
+
jax_if, jax_param = self._unwrap(if_op, expr.id, info)
|
|
935
1148
|
|
|
936
|
-
|
|
1149
|
+
# recursively compile arguments
|
|
937
1150
|
jax_pred = self._jax(pred, info)
|
|
938
1151
|
jax_true = self._jax(if_true, info)
|
|
939
1152
|
jax_false = self._jax(if_false, info)
|
|
@@ -951,23 +1164,20 @@ class JaxRDDLCompiler:
|
|
|
951
1164
|
|
|
952
1165
|
return _jax_wrapped_if_then_else
|
|
953
1166
|
|
|
954
|
-
def _jax_switch_helper(self):
|
|
955
|
-
|
|
956
|
-
def _jax_wrapped_switch_calc_exact(pred, cases, param):
|
|
957
|
-
pred = pred[jnp.newaxis, ...]
|
|
958
|
-
sample = jnp.take_along_axis(cases, pred, axis=0)
|
|
959
|
-
assert sample.shape[0] == 1
|
|
960
|
-
return sample[0, ...]
|
|
961
|
-
|
|
962
|
-
return _jax_wrapped_switch_calc_exact
|
|
963
|
-
|
|
964
1167
|
def _jax_switch(self, expr, info):
|
|
965
|
-
|
|
1168
|
+
|
|
1169
|
+
# if expression is non-fluent, always use the exact operation
|
|
1170
|
+
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
|
|
1171
|
+
switch_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
|
|
1172
|
+
else:
|
|
1173
|
+
switch_op = self.SWITCH_HELPER
|
|
1174
|
+
jax_switch, jax_param = self._unwrap(switch_op, expr.id, info)
|
|
1175
|
+
|
|
1176
|
+
# recursively compile predicate
|
|
1177
|
+
pred, *_ = expr.args
|
|
966
1178
|
jax_pred = self._jax(pred, info)
|
|
967
|
-
jax_switch, jax_param = self._unwrap(
|
|
968
|
-
self._jax_switch_helper(), expr.id, info)
|
|
969
1179
|
|
|
970
|
-
#
|
|
1180
|
+
# recursively compile cases
|
|
971
1181
|
cases, default = self.traced.cached_sim_info(expr)
|
|
972
1182
|
jax_default = None if default is None else self._jax(default, info)
|
|
973
1183
|
jax_cases = [(jax_default if _case is None else self._jax(_case, info))
|
|
@@ -983,7 +1193,8 @@ class JaxRDDLCompiler:
|
|
|
983
1193
|
for (i, jax_case) in enumerate(jax_cases):
|
|
984
1194
|
sample_cases[i], key, err_case = jax_case(x, params, key)
|
|
985
1195
|
err |= err_case
|
|
986
|
-
sample_cases = jnp.asarray(
|
|
1196
|
+
sample_cases = jnp.asarray(
|
|
1197
|
+
sample_cases, dtype=self._fix_dtype(sample_cases))
|
|
987
1198
|
|
|
988
1199
|
# predicate (enum) is an integer - use it to extract from case array
|
|
989
1200
|
param = params.get(jax_param, None)
|
|
@@ -1179,30 +1390,28 @@ class JaxRDDLCompiler:
|
|
|
1179
1390
|
scale, key, err2 = jax_scale(x, params, key)
|
|
1180
1391
|
key, subkey = random.split(key)
|
|
1181
1392
|
U = random.uniform(key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
|
|
1182
|
-
sample = scale * jnp.power(-jnp.
|
|
1393
|
+
sample = scale * jnp.power(-jnp.log(U), 1.0 / shape)
|
|
1183
1394
|
out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
|
|
1184
1395
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1185
1396
|
return sample, key, err
|
|
1186
1397
|
|
|
1187
1398
|
return _jax_wrapped_distribution_weibull
|
|
1188
1399
|
|
|
1189
|
-
def _jax_bernoulli_helper(self):
|
|
1190
|
-
|
|
1191
|
-
def _jax_wrapped_calc_bernoulli_exact(key, prob, param):
|
|
1192
|
-
return random.bernoulli(key, prob)
|
|
1193
|
-
|
|
1194
|
-
return _jax_wrapped_calc_bernoulli_exact
|
|
1195
|
-
|
|
1196
1400
|
def _jax_bernoulli(self, expr, info):
|
|
1197
1401
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_BERNOULLI']
|
|
1198
1402
|
JaxRDDLCompiler._check_num_args(expr, 1)
|
|
1199
|
-
jax_bern, jax_param = self._unwrap(
|
|
1200
|
-
self._jax_bernoulli_helper(), expr.id, info)
|
|
1201
|
-
|
|
1202
1403
|
arg_prob, = expr.args
|
|
1404
|
+
|
|
1405
|
+
# if probability is non-fluent, always use the exact operation
|
|
1406
|
+
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_prob):
|
|
1407
|
+
bern_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BERNOULLI
|
|
1408
|
+
else:
|
|
1409
|
+
bern_op = self.BERNOULLI_HELPER
|
|
1410
|
+
jax_bern, jax_param = self._unwrap(bern_op, expr.id, info)
|
|
1411
|
+
|
|
1412
|
+
# recursively compile arguments
|
|
1203
1413
|
jax_prob = self._jax(arg_prob, info)
|
|
1204
1414
|
|
|
1205
|
-
# uses the implicit JAX subroutine
|
|
1206
1415
|
def _jax_wrapped_distribution_bernoulli(x, params, key):
|
|
1207
1416
|
prob, key, err = jax_prob(x, params, key)
|
|
1208
1417
|
key, subkey = random.split(key)
|
|
@@ -1266,8 +1475,8 @@ class JaxRDDLCompiler:
|
|
|
1266
1475
|
def _jax_wrapped_distribution_binomial(x, params, key):
|
|
1267
1476
|
trials, key, err2 = jax_trials(x, params, key)
|
|
1268
1477
|
prob, key, err1 = jax_prob(x, params, key)
|
|
1269
|
-
trials = jnp.asarray(trials, self.REAL)
|
|
1270
|
-
prob = jnp.asarray(prob, self.REAL)
|
|
1478
|
+
trials = jnp.asarray(trials, dtype=self.REAL)
|
|
1479
|
+
prob = jnp.asarray(prob, dtype=self.REAL)
|
|
1271
1480
|
key, subkey = random.split(key)
|
|
1272
1481
|
dist = tfp.distributions.Binomial(total_count=trials, probs=prob)
|
|
1273
1482
|
sample = dist.sample(seed=subkey).astype(self.INT)
|
|
@@ -1290,11 +1499,10 @@ class JaxRDDLCompiler:
|
|
|
1290
1499
|
def _jax_wrapped_distribution_negative_binomial(x, params, key):
|
|
1291
1500
|
trials, key, err2 = jax_trials(x, params, key)
|
|
1292
1501
|
prob, key, err1 = jax_prob(x, params, key)
|
|
1293
|
-
trials = jnp.asarray(trials, self.REAL)
|
|
1294
|
-
prob = jnp.asarray(prob, self.REAL)
|
|
1502
|
+
trials = jnp.asarray(trials, dtype=self.REAL)
|
|
1503
|
+
prob = jnp.asarray(prob, dtype=self.REAL)
|
|
1295
1504
|
key, subkey = random.split(key)
|
|
1296
|
-
dist = tfp.distributions.NegativeBinomial(
|
|
1297
|
-
total_count=trials, probs=prob)
|
|
1505
|
+
dist = tfp.distributions.NegativeBinomial(total_count=trials, probs=prob)
|
|
1298
1506
|
sample = dist.sample(seed=subkey).astype(self.INT)
|
|
1299
1507
|
out_of_bounds = jnp.logical_not(jnp.all(
|
|
1300
1508
|
(prob >= 0) & (prob <= 1) & (trials > 0)))
|
|
@@ -1316,7 +1524,7 @@ class JaxRDDLCompiler:
|
|
|
1316
1524
|
shape, key, err1 = jax_shape(x, params, key)
|
|
1317
1525
|
rate, key, err2 = jax_rate(x, params, key)
|
|
1318
1526
|
key, subkey = random.split(key)
|
|
1319
|
-
sample = random.beta(key=subkey, a=shape, b=rate)
|
|
1527
|
+
sample = random.beta(key=subkey, a=shape, b=rate, dtype=self.REAL)
|
|
1320
1528
|
out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (rate > 0)))
|
|
1321
1529
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1322
1530
|
return sample, key, err
|
|
@@ -1325,23 +1533,35 @@ class JaxRDDLCompiler:
|
|
|
1325
1533
|
|
|
1326
1534
|
def _jax_geometric(self, expr, info):
|
|
1327
1535
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_GEOMETRIC']
|
|
1328
|
-
JaxRDDLCompiler._check_num_args(expr, 1)
|
|
1329
|
-
|
|
1536
|
+
JaxRDDLCompiler._check_num_args(expr, 1)
|
|
1330
1537
|
arg_prob, = expr.args
|
|
1331
1538
|
jax_prob = self._jax(arg_prob, info)
|
|
1332
|
-
floor_op, jax_param = self._unwrap(
|
|
1333
|
-
self.KNOWN_UNARY['floor'], expr.id, info)
|
|
1334
1539
|
|
|
1335
|
-
|
|
1336
|
-
|
|
1337
|
-
prob
|
|
1338
|
-
|
|
1339
|
-
|
|
1340
|
-
|
|
1341
|
-
|
|
1342
|
-
|
|
1343
|
-
|
|
1344
|
-
|
|
1540
|
+
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_prob):
|
|
1541
|
+
|
|
1542
|
+
# prob is non-fluent: do not reparameterize
|
|
1543
|
+
def _jax_wrapped_distribution_geometric(x, params, key):
|
|
1544
|
+
prob, key, err = jax_prob(x, params, key)
|
|
1545
|
+
key, subkey = random.split(key)
|
|
1546
|
+
sample = random.geometric(key=subkey, p=prob, dtype=self.INT)
|
|
1547
|
+
out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
|
|
1548
|
+
err |= (out_of_bounds * ERR)
|
|
1549
|
+
return sample, key, err
|
|
1550
|
+
|
|
1551
|
+
else:
|
|
1552
|
+
floor_op, jax_param = self._unwrap(
|
|
1553
|
+
self.KNOWN_UNARY['floor'], expr.id, info)
|
|
1554
|
+
|
|
1555
|
+
# reparameterization trick Geom(p) = floor(ln(U(0, 1)) / ln(p)) + 1
|
|
1556
|
+
def _jax_wrapped_distribution_geometric(x, params, key):
|
|
1557
|
+
prob, key, err = jax_prob(x, params, key)
|
|
1558
|
+
key, subkey = random.split(key)
|
|
1559
|
+
U = random.uniform(key=subkey, shape=jnp.shape(prob), dtype=self.REAL)
|
|
1560
|
+
param = params.get(jax_param, None)
|
|
1561
|
+
sample = floor_op(jnp.log(U) / jnp.log(1.0 - prob), param) + 1
|
|
1562
|
+
out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
|
|
1563
|
+
err |= (out_of_bounds * ERR)
|
|
1564
|
+
return sample, key, err
|
|
1345
1565
|
|
|
1346
1566
|
return _jax_wrapped_distribution_geometric
|
|
1347
1567
|
|
|
@@ -1359,7 +1579,7 @@ class JaxRDDLCompiler:
|
|
|
1359
1579
|
shape, key, err1 = jax_shape(x, params, key)
|
|
1360
1580
|
scale, key, err2 = jax_scale(x, params, key)
|
|
1361
1581
|
key, subkey = random.split(key)
|
|
1362
|
-
sample = scale * random.pareto(key=subkey, b=shape)
|
|
1582
|
+
sample = scale * random.pareto(key=subkey, b=shape, dtype=self.REAL)
|
|
1363
1583
|
out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
|
|
1364
1584
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1365
1585
|
return sample, key, err
|
|
@@ -1377,7 +1597,8 @@ class JaxRDDLCompiler:
|
|
|
1377
1597
|
def _jax_wrapped_distribution_t(x, params, key):
|
|
1378
1598
|
df, key, err = jax_df(x, params, key)
|
|
1379
1599
|
key, subkey = random.split(key)
|
|
1380
|
-
sample = random.t(
|
|
1600
|
+
sample = random.t(
|
|
1601
|
+
key=subkey, df=df, shape=jnp.shape(df), dtype=self.REAL)
|
|
1381
1602
|
out_of_bounds = jnp.logical_not(jnp.all(df > 0))
|
|
1382
1603
|
err |= (out_of_bounds * ERR)
|
|
1383
1604
|
return sample, key, err
|
|
@@ -1464,7 +1685,7 @@ class JaxRDDLCompiler:
|
|
|
1464
1685
|
scale, key, err2 = jax_scale(x, params, key)
|
|
1465
1686
|
key, subkey = random.split(key)
|
|
1466
1687
|
U = random.uniform(key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
|
|
1467
|
-
sample = jnp.log(1.0 - jnp.
|
|
1688
|
+
sample = jnp.log(1.0 - jnp.log(U) / shape) / scale
|
|
1468
1689
|
out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
|
|
1469
1690
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1470
1691
|
return sample, key, err
|
|
@@ -1516,25 +1737,21 @@ class JaxRDDLCompiler:
|
|
|
1516
1737
|
# random variables with enum support
|
|
1517
1738
|
# ===========================================================================
|
|
1518
1739
|
|
|
1519
|
-
def _jax_discrete_helper(self):
|
|
1520
|
-
|
|
1521
|
-
def _jax_wrapped_discrete_calc_exact(key, prob, param):
|
|
1522
|
-
logits = jnp.log(prob)
|
|
1523
|
-
sample = random.categorical(key=key, logits=logits, axis=-1)
|
|
1524
|
-
out_of_bounds = jnp.logical_not(jnp.logical_and(
|
|
1525
|
-
jnp.all(prob >= 0),
|
|
1526
|
-
jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
|
|
1527
|
-
return sample, out_of_bounds
|
|
1528
|
-
|
|
1529
|
-
return _jax_wrapped_discrete_calc_exact
|
|
1530
|
-
|
|
1531
1740
|
def _jax_discrete(self, expr, info, unnorm):
|
|
1532
1741
|
NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
|
|
1533
1742
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_DISCRETE']
|
|
1534
|
-
jax_discrete, jax_param = self._unwrap(
|
|
1535
|
-
self._jax_discrete_helper(), expr.id, info)
|
|
1536
|
-
|
|
1537
1743
|
ordered_args = self.traced.cached_sim_info(expr)
|
|
1744
|
+
|
|
1745
|
+
# if all probabilities are non-fluent, then always sample exact
|
|
1746
|
+
has_fluent_arg = any(self.traced.cached_is_fluent(arg)
|
|
1747
|
+
for arg in ordered_args)
|
|
1748
|
+
if self.compile_non_fluent_exact and not has_fluent_arg:
|
|
1749
|
+
discrete_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
|
|
1750
|
+
else:
|
|
1751
|
+
discrete_op = self.DISCRETE_HELPER
|
|
1752
|
+
jax_discrete, jax_param = self._unwrap(discrete_op, expr.id, info)
|
|
1753
|
+
|
|
1754
|
+
# compile probability expressions
|
|
1538
1755
|
jax_probs = [self._jax(arg, info) for arg in ordered_args]
|
|
1539
1756
|
|
|
1540
1757
|
def _jax_wrapped_distribution_discrete(x, params, key):
|
|
@@ -1561,12 +1778,18 @@ class JaxRDDLCompiler:
|
|
|
1561
1778
|
|
|
1562
1779
|
def _jax_discrete_pvar(self, expr, info, unnorm):
|
|
1563
1780
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_DISCRETE']
|
|
1564
|
-
JaxRDDLCompiler._check_num_args(expr,
|
|
1565
|
-
jax_discrete, jax_param = self._unwrap(
|
|
1566
|
-
self._jax_discrete_helper(), expr.id, info)
|
|
1567
|
-
|
|
1781
|
+
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1568
1782
|
_, args = expr.args
|
|
1569
1783
|
arg, = args
|
|
1784
|
+
|
|
1785
|
+
# if probabilities are non-fluent, then always sample exact
|
|
1786
|
+
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg):
|
|
1787
|
+
discrete_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
|
|
1788
|
+
else:
|
|
1789
|
+
discrete_op = self.DISCRETE_HELPER
|
|
1790
|
+
jax_discrete, jax_param = self._unwrap(discrete_op, expr.id, info)
|
|
1791
|
+
|
|
1792
|
+
# compile probability function
|
|
1570
1793
|
jax_probs = self._jax(arg, info)
|
|
1571
1794
|
|
|
1572
1795
|
def _jax_wrapped_distribution_discrete_pvar(x, params, key):
|
|
@@ -1687,7 +1910,7 @@ class JaxRDDLCompiler:
|
|
|
1687
1910
|
out_of_bounds = jnp.logical_not(jnp.all(alpha > 0))
|
|
1688
1911
|
error |= (out_of_bounds * ERR)
|
|
1689
1912
|
key, subkey = random.split(key)
|
|
1690
|
-
Gamma = random.gamma(key=subkey, a=alpha)
|
|
1913
|
+
Gamma = random.gamma(key=subkey, a=alpha, dtype=self.REAL)
|
|
1691
1914
|
sample = Gamma / jnp.sum(Gamma, axis=-1, keepdims=True)
|
|
1692
1915
|
sample = jnp.moveaxis(sample, source=-1, destination=index)
|
|
1693
1916
|
return sample, key, error
|
|
@@ -1706,8 +1929,8 @@ class JaxRDDLCompiler:
|
|
|
1706
1929
|
def _jax_wrapped_distribution_multinomial(x, params, key):
|
|
1707
1930
|
trials, key, err1 = jax_trials(x, params, key)
|
|
1708
1931
|
prob, key, err2 = jax_prob(x, params, key)
|
|
1709
|
-
trials = jnp.asarray(trials, self.REAL)
|
|
1710
|
-
prob = jnp.asarray(prob, self.REAL)
|
|
1932
|
+
trials = jnp.asarray(trials, dtype=self.REAL)
|
|
1933
|
+
prob = jnp.asarray(prob, dtype=self.REAL)
|
|
1711
1934
|
key, subkey = random.split(key)
|
|
1712
1935
|
dist = tfp.distributions.Multinomial(total_count=trials, probs=prob)
|
|
1713
1936
|
sample = dist.sample(seed=subkey).astype(self.INT)
|