pyRDDLGym-jax 2.8__py3-none-any.whl → 3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +1080 -906
- pyRDDLGym_jax/core/logic.py +1537 -1369
- pyRDDLGym_jax/core/model.py +75 -86
- pyRDDLGym_jax/core/planner.py +883 -935
- pyRDDLGym_jax/core/simulator.py +20 -17
- pyRDDLGym_jax/core/tuning.py +11 -7
- pyRDDLGym_jax/core/visualization.py +115 -78
- pyRDDLGym_jax/entry_point.py +2 -1
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +6 -8
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +7 -8
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +7 -8
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +8 -9
- pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +7 -8
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +6 -7
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +6 -7
- pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +6 -8
- pyRDDLGym_jax/examples/configs/Quadcopter_physics_drp.cfg +17 -0
- pyRDDLGym_jax/examples/configs/Quadcopter_physics_slp.cfg +17 -0
- pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +4 -7
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +4 -7
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +6 -7
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +6 -7
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +6 -7
- pyRDDLGym_jax/examples/configs/default_drp.cfg +5 -8
- pyRDDLGym_jax/examples/configs/default_replan.cfg +5 -8
- pyRDDLGym_jax/examples/configs/default_slp.cfg +5 -8
- pyRDDLGym_jax/examples/configs/tuning_drp.cfg +6 -8
- pyRDDLGym_jax/examples/configs/tuning_replan.cfg +6 -8
- pyRDDLGym_jax/examples/configs/tuning_slp.cfg +6 -8
- pyRDDLGym_jax/examples/run_plan.py +2 -2
- pyRDDLGym_jax/examples/run_tune.py +2 -2
- {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/METADATA +22 -23
- pyrddlgym_jax-3.0.dist-info/RECORD +51 -0
- {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax/examples/run_gradient.py +0 -102
- pyrddlgym_jax-2.8.dist-info/RECORD +0 -50
- {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/entry_points.txt +0 -0
- {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/licenses/LICENSE +0 -0
- {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/logic.py
CHANGED
|
@@ -29,26 +29,16 @@
|
|
|
29
29
|
#
|
|
30
30
|
# ***********************************************************************
|
|
31
31
|
|
|
32
|
+
import termcolor
|
|
33
|
+
from typing import Any, Dict, Optional, Set, Tuple, Union
|
|
32
34
|
|
|
33
|
-
|
|
34
|
-
import traceback
|
|
35
|
-
from typing import Callable, Dict, Tuple, Union
|
|
36
|
-
|
|
35
|
+
import numpy as np
|
|
37
36
|
import jax
|
|
38
37
|
import jax.numpy as jnp
|
|
39
38
|
import jax.random as random
|
|
40
39
|
import jax.scipy as scipy
|
|
41
40
|
|
|
42
|
-
from
|
|
43
|
-
|
|
44
|
-
# more robust approach - if user does not have this or broken try to continue
|
|
45
|
-
try:
|
|
46
|
-
from tensorflow_probability.substrates import jax as tfp
|
|
47
|
-
except Exception:
|
|
48
|
-
raise_warning('Failed to import tensorflow-probability: '
|
|
49
|
-
'compilation of some probability distributions will fail.', 'red')
|
|
50
|
-
traceback.print_exc()
|
|
51
|
-
tfp = None
|
|
41
|
+
from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler
|
|
52
42
|
|
|
53
43
|
|
|
54
44
|
def enumerate_literals(shape: Tuple[int, ...], axis: int, dtype: type=jnp.int32) -> jnp.ndarray:
|
|
@@ -59,1421 +49,1599 @@ def enumerate_literals(shape: Tuple[int, ...], axis: int, dtype: type=jnp.int32)
|
|
|
59
49
|
return literals
|
|
60
50
|
|
|
61
51
|
|
|
62
|
-
#
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
52
|
+
# branching sigmoid to help reduce numerical issues
|
|
53
|
+
@jax.custom_jvp
|
|
54
|
+
def stable_sigmoid(x: jnp.ndarray) -> jnp.ndarray:
|
|
55
|
+
return jnp.where(x >= 0, 1.0 / (1.0 + jnp.exp(-x)), jnp.exp(x) / (1.0 + jnp.exp(x)))
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@stable_sigmoid.defjvp
|
|
59
|
+
def stable_sigmoid_jvp(primals, tangents):
|
|
60
|
+
(x,), (x_dot,) = primals, tangents
|
|
61
|
+
s = stable_sigmoid(x)
|
|
62
|
+
primal_out = s
|
|
63
|
+
tangent_out = x_dot * s * (1.0 - s)
|
|
64
|
+
return primal_out, tangent_out
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
# branching tanh to help reduce numerical issues
|
|
68
|
+
@jax.custom_jvp
|
|
69
|
+
def stable_tanh(x: jnp.ndarray) -> jnp.ndarray:
|
|
70
|
+
ax = jnp.abs(x)
|
|
71
|
+
small = jnp.where(
|
|
72
|
+
ax < 20.0,
|
|
73
|
+
jnp.expm1(2.0 * ax) / (jnp.expm1(2.0 * ax) + 2.0),
|
|
74
|
+
1.0 - 2.0 * jnp.exp(-2.0 * ax)
|
|
75
|
+
)
|
|
76
|
+
return jnp.sign(x) * small
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@stable_tanh.defjvp
|
|
80
|
+
def stable_tanh_jvp(primals, tangents):
|
|
81
|
+
(x,), (x_dot,) = primals, tangents
|
|
82
|
+
t = stable_tanh(x)
|
|
83
|
+
tangent_out = x_dot * (1.0 - t * t)
|
|
84
|
+
return t, tangent_out
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
# it seems JAX uses the stability trick already
|
|
88
|
+
def stable_softmax_weight_sum(logits: jnp.ndarray,
|
|
89
|
+
values: jnp.ndarray,
|
|
90
|
+
axis: Union[int, Tuple[int, ...]]) -> jnp.ndarray:
|
|
91
|
+
return jnp.sum(values * jax.nn.softmax(logits), axis=axis)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
|
|
95
|
+
'''Compiles a RDDL AST representation to an equivalent JAX representation.
|
|
96
|
+
Unlike its parent class, this class treats all fluents as real-valued, and
|
|
97
|
+
replaces all mathematical operations by equivalent ones with a well defined
|
|
98
|
+
(e.g. non-zero) gradient where appropriate.
|
|
99
|
+
'''
|
|
100
|
+
|
|
101
|
+
def __init__(self, *args,
|
|
102
|
+
cpfs_without_grad: Optional[Set[str]]=None,
|
|
103
|
+
print_warnings: bool=True,
|
|
104
|
+
**kwargs) -> None:
|
|
105
|
+
'''Creates a new RDDL to Jax compiler, where operations that are not
|
|
106
|
+
differentiable are converted to approximate forms that have defined gradients.
|
|
107
|
+
|
|
108
|
+
:param *args: arguments to pass to base compiler
|
|
109
|
+
:param cpfs_without_grad: which CPFs do not have gradients (use straight
|
|
110
|
+
through gradient trick)
|
|
111
|
+
:param print_warnings: whether to print warnings
|
|
112
|
+
:param *kwargs: keyword arguments to pass to base compiler
|
|
113
|
+
'''
|
|
114
|
+
super(JaxRDDLCompilerWithGrad, self).__init__(*args, **kwargs)
|
|
115
|
+
|
|
116
|
+
if cpfs_without_grad is None:
|
|
117
|
+
cpfs_without_grad = set()
|
|
118
|
+
self.cpfs_without_grad = cpfs_without_grad
|
|
119
|
+
self.print_warnings = print_warnings
|
|
120
|
+
|
|
121
|
+
# actions and CPFs must be continuous
|
|
122
|
+
pvars_cast = set()
|
|
123
|
+
for (var, values) in self.init_values.items():
|
|
124
|
+
self.init_values[var] = np.asarray(values, dtype=self.REAL)
|
|
125
|
+
if not np.issubdtype(np.result_type(values), np.floating):
|
|
126
|
+
pvars_cast.add(var)
|
|
127
|
+
if self.print_warnings and pvars_cast:
|
|
128
|
+
print(termcolor.colored(
|
|
129
|
+
f'[INFO] Compiler will cast pvars {pvars_cast} to float.', 'dark_grey'))
|
|
130
|
+
|
|
131
|
+
def get_kwargs(self) -> Dict[str, Any]:
|
|
132
|
+
kwargs = super().get_kwargs()
|
|
133
|
+
kwargs['cpfs_without_grad'] = self.cpfs_without_grad
|
|
134
|
+
kwargs['print_warnings'] = self.print_warnings
|
|
135
|
+
return kwargs
|
|
136
|
+
|
|
137
|
+
def _jax_stop_grad(self, jax_expr):
|
|
138
|
+
def _jax_wrapped_stop_grad(fls, nfls, params, key):
|
|
139
|
+
sample, key, error, params = jax_expr(fls, nfls, params, key)
|
|
140
|
+
sample = jax.lax.stop_gradient(sample)
|
|
141
|
+
return sample, key, error, params
|
|
142
|
+
return _jax_wrapped_stop_grad
|
|
143
|
+
|
|
144
|
+
def _compile_cpfs(self, aux):
|
|
145
|
+
|
|
146
|
+
# cpfs will all be cast to float
|
|
147
|
+
cpfs_cast = set()
|
|
148
|
+
jax_cpfs = {}
|
|
149
|
+
for (_, cpfs) in self.levels.items():
|
|
150
|
+
for cpf in cpfs:
|
|
151
|
+
_, expr = self.rddl.cpfs[cpf]
|
|
152
|
+
jax_cpfs[cpf] = self._jax(expr, aux, dtype=self.REAL)
|
|
153
|
+
if self.rddl.variable_ranges[cpf] != 'real':
|
|
154
|
+
cpfs_cast.add(cpf)
|
|
155
|
+
if cpf in self.cpfs_without_grad:
|
|
156
|
+
jax_cpfs[cpf] = self._jax_stop_grad(jax_cpfs[cpf])
|
|
157
|
+
|
|
158
|
+
if self.print_warnings and cpfs_cast:
|
|
159
|
+
print(termcolor.colored(
|
|
160
|
+
f'[INFO] Compiler will cast CPFs {cpfs_cast} to float.', 'dark_grey'))
|
|
161
|
+
if self.print_warnings and self.cpfs_without_grad:
|
|
162
|
+
print(termcolor.colored(
|
|
163
|
+
f'[INFO] Gradient disabled for CPFs {self.cpfs_without_grad}.', 'dark_grey'))
|
|
164
|
+
|
|
165
|
+
return jax_cpfs
|
|
166
|
+
|
|
167
|
+
def _jax_unary_with_param(self, jax_expr, jax_op):
|
|
168
|
+
def _jax_wrapped_unary_op_with_param(fls, nfls, params, key):
|
|
169
|
+
sample, key, err, params = jax_expr(fls, nfls, params, key)
|
|
170
|
+
sample = self.ONE * sample
|
|
171
|
+
sample, params = jax_op(sample, params)
|
|
172
|
+
return sample, key, err, params
|
|
173
|
+
return _jax_wrapped_unary_op_with_param
|
|
174
|
+
|
|
175
|
+
def _jax_binary_with_param(self, jax_lhs, jax_rhs, jax_op):
|
|
176
|
+
def _jax_wrapped_binary_op_with_param(fls, nfls, params, key):
|
|
177
|
+
sample1, key, err1, params = jax_lhs(fls, nfls, params, key)
|
|
178
|
+
sample2, key, err2, params = jax_rhs(fls, nfls, params, key)
|
|
179
|
+
sample1 = self.ONE * sample1
|
|
180
|
+
sample2 = self.ONE * sample2
|
|
181
|
+
sample, params = jax_op(sample1, sample2, params)
|
|
182
|
+
err = err1 | err2
|
|
183
|
+
return sample, key, err, params
|
|
184
|
+
return _jax_wrapped_binary_op_with_param
|
|
185
|
+
|
|
186
|
+
def _jax_unary_helper_with_param(self, expr, aux, jax_op):
|
|
187
|
+
JaxRDDLCompilerWithGrad._check_num_args(expr, 1)
|
|
188
|
+
arg, = expr.args
|
|
189
|
+
jax_arg = self._jax(arg, aux)
|
|
190
|
+
return self._jax_unary_with_param(jax_arg, jax_op)
|
|
191
|
+
|
|
192
|
+
def _jax_binary_helper_with_param(self, expr, aux, jax_op):
|
|
193
|
+
JaxRDDLCompilerWithGrad._check_num_args(expr, 2)
|
|
194
|
+
lhs, rhs = expr.args
|
|
195
|
+
jax_lhs = self._jax(lhs, aux)
|
|
196
|
+
jax_rhs = self._jax(rhs, aux)
|
|
197
|
+
return self._jax_binary_with_param(jax_lhs, jax_rhs, jax_op)
|
|
198
|
+
|
|
199
|
+
def _jax_kron(self, expr, aux):
|
|
200
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
201
|
+
arg, = expr.args
|
|
202
|
+
arg = self._jax(arg, aux)
|
|
203
|
+
return arg
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
# ===============================================================================
|
|
207
|
+
# relational relaxations
|
|
208
|
+
# ===============================================================================
|
|
209
|
+
|
|
210
|
+
# https://arxiv.org/abs/2110.05651
|
|
211
|
+
class SigmoidRelational(JaxRDDLCompilerWithGrad):
|
|
94
212
|
'''Comparison operations approximated using sigmoid functions.'''
|
|
95
213
|
|
|
96
|
-
def __init__(self,
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
return
|
|
110
|
-
|
|
111
|
-
def
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
id_ =
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
sample = jnp.sum(literals * softmax, axis=axis)
|
|
214
|
+
def __init__(self, *args, sigmoid_weight: float=10.,
|
|
215
|
+
use_sigmoid_ste: bool=True, use_tanh_ste: bool=True,
|
|
216
|
+
**kwargs) -> None:
|
|
217
|
+
super(SigmoidRelational, self).__init__(*args, **kwargs)
|
|
218
|
+
self.sigmoid_weight = float(sigmoid_weight)
|
|
219
|
+
self.use_sigmoid_ste = use_sigmoid_ste
|
|
220
|
+
self.use_tanh_ste = use_tanh_ste
|
|
221
|
+
|
|
222
|
+
def get_kwargs(self) -> Dict[str, Any]:
|
|
223
|
+
kwargs = super().get_kwargs()
|
|
224
|
+
kwargs['sigmoid_weight'] = self.sigmoid_weight
|
|
225
|
+
kwargs['use_sigmoid_ste'] = self.use_sigmoid_ste
|
|
226
|
+
kwargs['use_tanh_ste'] = self.use_tanh_ste
|
|
227
|
+
return kwargs
|
|
228
|
+
|
|
229
|
+
def _jax_greater(self, expr, aux):
|
|
230
|
+
if not self.traced.cached_is_fluent(expr):
|
|
231
|
+
return super()._jax_greater(expr, aux)
|
|
232
|
+
id_ = expr.id
|
|
233
|
+
aux['params'][id_] = self.sigmoid_weight
|
|
234
|
+
aux['overriden'][id_] = __class__.__name__
|
|
235
|
+
def greater_op(x, y, params):
|
|
236
|
+
sample = stable_sigmoid(params[id_] * (x - y))
|
|
237
|
+
if self.use_sigmoid_ste:
|
|
238
|
+
sample = sample + jax.lax.stop_gradient(jnp.greater(x, y) - sample)
|
|
239
|
+
return sample, params
|
|
240
|
+
return self._jax_binary_helper_with_param(expr, aux, greater_op)
|
|
241
|
+
|
|
242
|
+
def _jax_greater_equal(self, expr, aux):
|
|
243
|
+
if not self.traced.cached_is_fluent(expr):
|
|
244
|
+
return super()._jax_greater(expr, aux)
|
|
245
|
+
id_ = expr.id
|
|
246
|
+
aux['params'][id_] = self.sigmoid_weight
|
|
247
|
+
aux['overriden'][id_] = __class__.__name__
|
|
248
|
+
def greater_equal_op(x, y, params):
|
|
249
|
+
sample = stable_sigmoid(params[id_] * (x - y))
|
|
250
|
+
if self.use_sigmoid_ste:
|
|
251
|
+
sample = sample + jax.lax.stop_gradient(jnp.greater_equal(x, y) - sample)
|
|
135
252
|
return sample, params
|
|
136
|
-
return
|
|
253
|
+
return self._jax_binary_helper_with_param(expr, aux, greater_equal_op)
|
|
254
|
+
|
|
255
|
+
def _jax_less(self, expr, aux):
|
|
256
|
+
if not self.traced.cached_is_fluent(expr):
|
|
257
|
+
return super()._jax_less(expr, aux)
|
|
258
|
+
id_ = expr.id
|
|
259
|
+
aux['params'][id_] = self.sigmoid_weight
|
|
260
|
+
aux['overriden'][id_] = __class__.__name__
|
|
261
|
+
def less_op(x, y, params):
|
|
262
|
+
sample = stable_sigmoid(params[id_] * (y - x))
|
|
263
|
+
if self.use_sigmoid_ste:
|
|
264
|
+
sample = sample + jax.lax.stop_gradient(jnp.less(x, y) - sample)
|
|
265
|
+
return sample, params
|
|
266
|
+
return self._jax_binary_helper_with_param(expr, aux, less_op)
|
|
267
|
+
|
|
268
|
+
def _jax_less_equal(self, expr, aux):
|
|
269
|
+
if not self.traced.cached_is_fluent(expr):
|
|
270
|
+
return super()._jax_less(expr, aux)
|
|
271
|
+
id_ = expr.id
|
|
272
|
+
aux['params'][id_] = self.sigmoid_weight
|
|
273
|
+
aux['overriden'][id_] = __class__.__name__
|
|
274
|
+
def less_equal_op(x, y, params):
|
|
275
|
+
sample = stable_sigmoid(params[id_] * (y - x))
|
|
276
|
+
if self.use_sigmoid_ste:
|
|
277
|
+
sample = sample + jax.lax.stop_gradient(jnp.less_equal(x, y) - sample)
|
|
278
|
+
return sample, params
|
|
279
|
+
return self._jax_binary_helper_with_param(expr, aux, less_equal_op)
|
|
280
|
+
|
|
281
|
+
def _jax_equal(self, expr, aux):
|
|
282
|
+
if not self.traced.cached_is_fluent(expr):
|
|
283
|
+
return super()._jax_equal(expr, aux)
|
|
284
|
+
id_ = expr.id
|
|
285
|
+
aux['params'][id_] = self.sigmoid_weight
|
|
286
|
+
aux['overriden'][id_] = __class__.__name__
|
|
287
|
+
def equal_op(x, y, params):
|
|
288
|
+
sample = 1. - jnp.square(stable_tanh(params[id_] * (y - x)))
|
|
289
|
+
if self.use_tanh_ste:
|
|
290
|
+
sample = sample + jax.lax.stop_gradient(jnp.equal(x, y) - sample)
|
|
291
|
+
return sample, params
|
|
292
|
+
return self._jax_binary_helper_with_param(expr, aux, equal_op)
|
|
293
|
+
|
|
294
|
+
def _jax_not_equal(self, expr, aux):
|
|
295
|
+
if not self.traced.cached_is_fluent(expr):
|
|
296
|
+
return super()._jax_not_equal(expr, aux)
|
|
297
|
+
id_ = expr.id
|
|
298
|
+
aux['params'][id_] = self.sigmoid_weight
|
|
299
|
+
aux['overriden'][id_] = __class__.__name__
|
|
300
|
+
def not_equal_op(x, y, params):
|
|
301
|
+
sample = jnp.square(stable_tanh(params[id_] * (y - x)))
|
|
302
|
+
if self.use_tanh_ste:
|
|
303
|
+
sample = sample + jax.lax.stop_gradient(jnp.not_equal(x, y) - sample)
|
|
304
|
+
return sample, params
|
|
305
|
+
return self._jax_binary_helper_with_param(expr, aux, not_equal_op)
|
|
306
|
+
|
|
307
|
+
def _jax_sgn(self, expr, aux):
|
|
308
|
+
if not self.traced.cached_is_fluent(expr):
|
|
309
|
+
return super()._jax_sgn(expr, aux)
|
|
310
|
+
id_ = expr.id
|
|
311
|
+
aux['params'][id_] = self.sigmoid_weight
|
|
312
|
+
aux['overriden'][id_] = __class__.__name__
|
|
313
|
+
def sgn_op(x, params):
|
|
314
|
+
sample = stable_tanh(params[id_] * x)
|
|
315
|
+
if self.use_tanh_ste:
|
|
316
|
+
sample = sample + jax.lax.stop_gradient(jnp.sign(x) - sample)
|
|
317
|
+
return sample, params
|
|
318
|
+
return self._jax_unary_helper_with_param(expr, aux, sgn_op)
|
|
137
319
|
|
|
138
|
-
def __str__(self) -> str:
|
|
139
|
-
return f'Sigmoid comparison with weight {self.weight}'
|
|
140
320
|
|
|
321
|
+
class SoftmaxArgmax(JaxRDDLCompilerWithGrad):
|
|
322
|
+
'''Argmin/argmax operations approximated using softmax functions.'''
|
|
141
323
|
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
# - soft rounding
|
|
146
|
-
#
|
|
147
|
-
# ===========================================================================
|
|
148
|
-
|
|
149
|
-
class Rounding(metaclass=ABCMeta):
|
|
150
|
-
'''Base class for approximate rounding operations.'''
|
|
151
|
-
|
|
152
|
-
@abstractmethod
|
|
153
|
-
def floor(self, id, init_params):
|
|
154
|
-
pass
|
|
324
|
+
def __init__(self, *args, argmax_weight: float=10., **kwargs) -> None:
|
|
325
|
+
super(SoftmaxArgmax, self).__init__(*args, **kwargs)
|
|
326
|
+
self.argmax_weight = float(argmax_weight)
|
|
155
327
|
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
328
|
+
def get_kwargs(self) -> Dict[str, Any]:
|
|
329
|
+
kwargs = super().get_kwargs()
|
|
330
|
+
kwargs['argmax_weight'] = self.argmax_weight
|
|
331
|
+
return kwargs
|
|
332
|
+
|
|
333
|
+
@staticmethod
|
|
334
|
+
def soft_argmax(x: jnp.ndarray, w: float, axes: Union[int, Tuple[int, ...]]) -> jnp.ndarray:
|
|
335
|
+
literals = enumerate_literals(jnp.shape(x), axis=axes)
|
|
336
|
+
return stable_softmax_weight_sum(w * x, literals, axis=axes)
|
|
337
|
+
|
|
338
|
+
def _jax_argmax(self, expr, aux):
|
|
339
|
+
if not self.traced.cached_is_fluent(expr):
|
|
340
|
+
return super()._jax_argmax(expr, aux)
|
|
341
|
+
id_ = expr.id
|
|
342
|
+
aux['params'][id_] = self.argmax_weight
|
|
343
|
+
aux['overriden'][id_] = __class__.__name__
|
|
344
|
+
arg = expr.args[-1]
|
|
345
|
+
_, axes = self.traced.cached_sim_info(expr)
|
|
346
|
+
jax_expr = self._jax(arg, aux)
|
|
347
|
+
def argmax_op(x, params):
|
|
348
|
+
sample = self.soft_argmax(x, params[id_], axes)
|
|
349
|
+
return sample, params
|
|
350
|
+
return self._jax_unary_with_param(jax_expr, argmax_op)
|
|
351
|
+
|
|
352
|
+
def _jax_argmin(self, expr, aux):
|
|
353
|
+
if not self.traced.cached_is_fluent(expr):
|
|
354
|
+
return super()._jax_argmin(expr, aux)
|
|
355
|
+
id_ = expr.id
|
|
356
|
+
aux['params'][id_] = self.argmax_weight
|
|
357
|
+
aux['overriden'][id_] = __class__.__name__
|
|
358
|
+
arg = expr.args[-1]
|
|
359
|
+
_, axes = self.traced.cached_sim_info(expr)
|
|
360
|
+
jax_expr = self._jax(arg, aux)
|
|
361
|
+
def argmin_op(x, params):
|
|
362
|
+
sample = self.soft_argmax(-x, params[id_], axes)
|
|
363
|
+
return sample, params
|
|
364
|
+
return self._jax_unary_with_param(jax_expr, argmin_op)
|
|
365
|
+
|
|
159
366
|
|
|
367
|
+
# ===============================================================================
|
|
368
|
+
# logical relaxations
|
|
369
|
+
# ===============================================================================
|
|
160
370
|
|
|
161
|
-
class
|
|
162
|
-
'''
|
|
371
|
+
class ProductNormLogical(JaxRDDLCompilerWithGrad):
|
|
372
|
+
'''Product t-norm given by the expression (x, y) -> x * y.'''
|
|
163
373
|
|
|
164
|
-
def __init__(self,
|
|
165
|
-
self.
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
def
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
374
|
+
def __init__(self, *args, use_logic_ste: bool=False, **kwargs) -> None:
|
|
375
|
+
super(ProductNormLogical, self).__init__(*args, **kwargs)
|
|
376
|
+
self.use_logic_ste = use_logic_ste
|
|
377
|
+
|
|
378
|
+
def get_kwargs(self) -> Dict[str, Any]:
|
|
379
|
+
kwargs = super().get_kwargs()
|
|
380
|
+
kwargs['use_logic_ste'] = self.use_logic_ste
|
|
381
|
+
return kwargs
|
|
382
|
+
|
|
383
|
+
def _jax_not(self, expr, aux):
|
|
384
|
+
if not self.traced.cached_is_fluent(expr):
|
|
385
|
+
return super()._jax_not(expr, aux)
|
|
386
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
387
|
+
def not_op(x):
|
|
388
|
+
sample = 1. - x
|
|
389
|
+
if self.use_logic_ste:
|
|
390
|
+
hard_sample = jnp.asarray(x <= 0.5, dtype=self.REAL)
|
|
391
|
+
sample = sample + jax.lax.stop_gradient(hard_sample - sample)
|
|
392
|
+
return sample
|
|
393
|
+
return self._jax_unary_helper(expr, aux, not_op)
|
|
394
|
+
|
|
395
|
+
def _jax_and(self, expr, aux):
|
|
396
|
+
if not self.traced.cached_is_fluent(expr):
|
|
397
|
+
return super()._jax_and(expr, aux)
|
|
398
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
399
|
+
def and_op(x, y):
|
|
400
|
+
sample = jnp.multiply(x, y)
|
|
401
|
+
if self.use_logic_ste:
|
|
402
|
+
hard_sample = jnp.asarray(jnp.logical_and(x > 0.5, y > 0.5), dtype=self.REAL)
|
|
403
|
+
sample = sample + jax.lax.stop_gradient(hard_sample - sample)
|
|
404
|
+
return sample
|
|
405
|
+
return self._jax_nary_helper(expr, aux, and_op)
|
|
406
|
+
|
|
407
|
+
def _jax_or(self, expr, aux):
|
|
408
|
+
if not self.traced.cached_is_fluent(expr):
|
|
409
|
+
return super()._jax_or(expr, aux)
|
|
410
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
411
|
+
def or_op(x, y):
|
|
412
|
+
sample = 1. - (1. - x) * (1. - y)
|
|
413
|
+
if self.use_logic_ste:
|
|
414
|
+
hard_sample = jnp.asarray(jnp.logical_or(x > 0.5, y > 0.5), dtype=self.REAL)
|
|
415
|
+
sample = sample + jax.lax.stop_gradient(hard_sample - sample)
|
|
416
|
+
return sample
|
|
417
|
+
return self._jax_nary_helper(expr, aux, or_op)
|
|
418
|
+
|
|
419
|
+
def _jax_xor(self, expr, aux):
|
|
420
|
+
if not self.traced.cached_is_fluent(expr):
|
|
421
|
+
return super()._jax_xor(expr, aux)
|
|
422
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
423
|
+
def xor_op(x, y):
|
|
424
|
+
sample = (1. - (1. - x) * (1. - y)) * (1. - x * y)
|
|
425
|
+
if self.use_logic_ste:
|
|
426
|
+
hard_sample = jnp.asarray(jnp.logical_xor(x > 0.5, y > 0.5), dtype=self.REAL)
|
|
427
|
+
sample = sample + jax.lax.stop_gradient(hard_sample - sample)
|
|
428
|
+
return sample
|
|
429
|
+
return self._jax_binary_helper(expr, aux, xor_op)
|
|
430
|
+
|
|
431
|
+
def _jax_implies(self, expr, aux):
|
|
432
|
+
if not self.traced.cached_is_fluent(expr):
|
|
433
|
+
return super()._jax_implies(expr, aux)
|
|
434
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
435
|
+
def implies_op(x, y):
|
|
436
|
+
sample = 1. - x * (1. - y)
|
|
437
|
+
if self.use_logic_ste:
|
|
438
|
+
hard_sample = jnp.asarray(jnp.logical_or(x <= 0.5, y > 0.5), dtype=self.REAL)
|
|
439
|
+
sample = sample + jax.lax.stop_gradient(hard_sample - sample)
|
|
440
|
+
return sample
|
|
441
|
+
return self._jax_binary_helper(expr, aux, implies_op)
|
|
442
|
+
|
|
443
|
+
def _jax_equiv(self, expr, aux):
|
|
444
|
+
if not self.traced.cached_is_fluent(expr):
|
|
445
|
+
return super()._jax_equiv(expr, aux)
|
|
446
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
447
|
+
def equiv_op(x, y):
|
|
448
|
+
sample = (1. - x * (1. - y)) * (1. - y * (1. - x))
|
|
449
|
+
if self.use_logic_ste:
|
|
450
|
+
hard_sample = jnp.logical_and(
|
|
451
|
+
jnp.logical_or(x <= 0.5, y > 0.5), jnp.logical_or(y <= 0.5, x > 0.5))
|
|
452
|
+
hard_sample = jnp.asarray(hard_sample, dtype=self.REAL)
|
|
453
|
+
sample = sample + jax.lax.stop_gradient(hard_sample - sample)
|
|
454
|
+
return sample
|
|
455
|
+
return self._jax_binary_helper(expr, aux, equiv_op)
|
|
456
|
+
|
|
457
|
+
def _jax_forall(self, expr, aux):
|
|
458
|
+
if not self.traced.cached_is_fluent(expr):
|
|
459
|
+
return super()._jax_forall(expr, aux)
|
|
460
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
461
|
+
def forall_op(x, axis):
|
|
462
|
+
sample = jnp.prod(x, axis=axis)
|
|
463
|
+
if self.use_logic_ste:
|
|
464
|
+
hard_sample = jnp.all(x, axis=axis)
|
|
465
|
+
sample = sample + jax.lax.stop_gradient(hard_sample - sample)
|
|
466
|
+
return sample
|
|
467
|
+
return self._jax_aggregation_helper(expr, aux, forall_op)
|
|
468
|
+
|
|
469
|
+
def _jax_exists(self, expr, aux):
|
|
470
|
+
if not self.traced.cached_is_fluent(expr):
|
|
471
|
+
return super()._jax_exists(expr, aux)
|
|
472
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
473
|
+
def exists_op(x, axis):
|
|
474
|
+
sample = 1. - jnp.prod(1. - x, axis=axis)
|
|
475
|
+
if self.use_logic_ste:
|
|
476
|
+
hard_sample = jnp.any(x, axis=axis)
|
|
477
|
+
sample = sample + jax.lax.stop_gradient(hard_sample - sample)
|
|
478
|
+
return sample
|
|
479
|
+
return self._jax_aggregation_helper(expr, aux, exists_op)
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
class GodelNormLogical(JaxRDDLCompilerWithGrad):
|
|
483
|
+
'''Godel t-norm given by the expression (x, y) -> min(x, y).'''
|
|
178
484
|
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
485
|
+
def __init__(self, *args, use_logic_ste: bool=False, **kwargs) -> None:
|
|
486
|
+
super(GodelNormLogical, self).__init__(*args, **kwargs)
|
|
487
|
+
self.use_logic_ste = use_logic_ste
|
|
488
|
+
|
|
489
|
+
def get_kwargs(self) -> Dict[str, Any]:
|
|
490
|
+
kwargs = super().get_kwargs()
|
|
491
|
+
kwargs['use_logic_ste'] = self.use_logic_ste
|
|
492
|
+
return kwargs
|
|
493
|
+
|
|
494
|
+
def _jax_not(self, expr, aux):
|
|
495
|
+
if not self.traced.cached_is_fluent(expr):
|
|
496
|
+
return super()._jax_not(expr, aux)
|
|
497
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
498
|
+
def not_op(x):
|
|
499
|
+
sample = 1. - x
|
|
500
|
+
if self.use_logic_ste:
|
|
501
|
+
hard_sample = jnp.asarray(x <= 0.5, dtype=self.REAL)
|
|
502
|
+
sample = sample + jax.lax.stop_gradient(hard_sample - sample)
|
|
503
|
+
return sample
|
|
504
|
+
return self._jax_unary_helper(expr, aux, not_op)
|
|
505
|
+
|
|
506
|
+
def _jax_and(self, expr, aux):
|
|
507
|
+
if not self.traced.cached_is_fluent(expr):
|
|
508
|
+
return super()._jax_and(expr, aux)
|
|
509
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
510
|
+
def and_op(x, y):
|
|
511
|
+
sample = jnp.minimum(x, y)
|
|
512
|
+
if self.use_logic_ste:
|
|
513
|
+
hard_sample = jnp.asarray(jnp.logical_and(x > 0.5, y > 0.5), dtype=self.REAL)
|
|
514
|
+
sample = sample + jax.lax.stop_gradient(hard_sample - sample)
|
|
515
|
+
return sample
|
|
516
|
+
return self._jax_nary_helper(expr, aux, and_op)
|
|
517
|
+
|
|
518
|
+
def _jax_or(self, expr, aux):
|
|
519
|
+
if not self.traced.cached_is_fluent(expr):
|
|
520
|
+
return super()._jax_or(expr, aux)
|
|
521
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
522
|
+
def or_op(x, y):
|
|
523
|
+
sample = jnp.maximum(x, y)
|
|
524
|
+
if self.use_logic_ste:
|
|
525
|
+
hard_sample = jnp.asarray(jnp.logical_or(x > 0.5, y > 0.5), dtype=self.REAL)
|
|
526
|
+
sample = sample + jax.lax.stop_gradient(hard_sample - sample)
|
|
527
|
+
return sample
|
|
528
|
+
return self._jax_nary_helper(expr, aux, or_op)
|
|
529
|
+
|
|
530
|
+
def _jax_xor(self, expr, aux):
|
|
531
|
+
if not self.traced.cached_is_fluent(expr):
|
|
532
|
+
return super()._jax_xor(expr, aux)
|
|
533
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
534
|
+
def xor_op(x, y):
|
|
535
|
+
sample = jnp.minimum(jnp.maximum(x, y), 1. - jnp.minimum(x, y))
|
|
536
|
+
if self.use_logic_ste:
|
|
537
|
+
hard_sample = jnp.asarray(jnp.logical_xor(x > 0.5, y > 0.5), dtype=self.REAL)
|
|
538
|
+
sample = sample + jax.lax.stop_gradient(hard_sample - sample)
|
|
539
|
+
return sample
|
|
540
|
+
return self._jax_binary_helper(expr, aux, xor_op)
|
|
541
|
+
|
|
542
|
+
def _jax_implies(self, expr, aux):
|
|
543
|
+
if not self.traced.cached_is_fluent(expr):
|
|
544
|
+
return super()._jax_implies(expr, aux)
|
|
545
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
546
|
+
def implies_op(x, y):
|
|
547
|
+
sample = jnp.maximum(1. - x, y)
|
|
548
|
+
if self.use_logic_ste:
|
|
549
|
+
hard_sample = jnp.asarray(jnp.logical_or(x <= 0.5, y > 0.5), dtype=self.REAL)
|
|
550
|
+
sample = sample + jax.lax.stop_gradient(hard_sample - sample)
|
|
551
|
+
return sample
|
|
552
|
+
return self._jax_binary_helper(expr, aux, implies_op)
|
|
553
|
+
|
|
554
|
+
def _jax_equiv(self, expr, aux):
|
|
555
|
+
if not self.traced.cached_is_fluent(expr):
|
|
556
|
+
return super()._jax_equiv(expr, aux)
|
|
557
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
558
|
+
def equiv_op(x, y):
|
|
559
|
+
sample = jnp.minimum(jnp.maximum(1. - x, y), jnp.maximum(1. - y, x))
|
|
560
|
+
if self.use_logic_ste:
|
|
561
|
+
hard_sample = jnp.logical_and(
|
|
562
|
+
jnp.logical_or(x <= 0.5, y > 0.5), jnp.logical_or(y <= 0.5, x > 0.5))
|
|
563
|
+
hard_sample = jnp.asarray(hard_sample, dtype=self.REAL)
|
|
564
|
+
sample = sample + jax.lax.stop_gradient(hard_sample - sample)
|
|
565
|
+
return sample
|
|
566
|
+
return self._jax_binary_helper(expr, aux, equiv_op)
|
|
567
|
+
|
|
568
|
+
def _jax_forall(self, expr, aux):
|
|
569
|
+
if not self.traced.cached_is_fluent(expr):
|
|
570
|
+
return super()._jax_forall(expr, aux)
|
|
571
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
572
|
+
def all_op(x, axis):
|
|
573
|
+
sample = jnp.min(x, axis=axis)
|
|
574
|
+
if self.use_logic_ste:
|
|
575
|
+
hard_sample = jnp.all(x, axis=axis)
|
|
576
|
+
sample = sample + jax.lax.stop_gradient(hard_sample - sample)
|
|
577
|
+
return sample
|
|
578
|
+
return self._jax_aggregation_helper(expr, aux, all_op)
|
|
579
|
+
|
|
580
|
+
def _jax_exists(self, expr, aux):
|
|
581
|
+
if not self.traced.cached_is_fluent(expr):
|
|
582
|
+
return super()._jax_exists(expr, aux)
|
|
583
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
584
|
+
def exists_op(x, axis):
|
|
585
|
+
sample = jnp.max(x, axis=axis)
|
|
586
|
+
if self.use_logic_ste:
|
|
587
|
+
hard_sample = jnp.any(x, axis=axis)
|
|
588
|
+
sample = sample + jax.lax.stop_gradient(hard_sample - sample)
|
|
589
|
+
return sample
|
|
590
|
+
return self._jax_aggregation_helper(expr, aux, exists_op)
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
class LukasiewiczNormLogical(JaxRDDLCompilerWithGrad):
|
|
594
|
+
'''Lukasiewicz t-norm given by the expression (x, y) -> max(x + y - 1, 0).'''
|
|
595
|
+
|
|
596
|
+
def __init__(self, *args, use_logic_ste: bool=False, **kwargs) -> None:
|
|
597
|
+
super(LukasiewiczNormLogical, self).__init__(*args, **kwargs)
|
|
598
|
+
self.use_logic_ste = use_logic_ste
|
|
599
|
+
|
|
600
|
+
def get_kwargs(self) -> Dict[str, Any]:
|
|
601
|
+
kwargs = super().get_kwargs()
|
|
602
|
+
kwargs['use_logic_ste'] = self.use_logic_ste
|
|
603
|
+
return kwargs
|
|
604
|
+
|
|
605
|
+
def _jax_not(self, expr, aux):
|
|
606
|
+
if not self.traced.cached_is_fluent(expr):
|
|
607
|
+
return super()._jax_not(expr, aux)
|
|
608
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
609
|
+
def not_op(x):
|
|
610
|
+
sample = 1. - x
|
|
611
|
+
if self.use_logic_ste:
|
|
612
|
+
hard_sample = jnp.asarray(x <= 0.5, dtype=self.REAL)
|
|
613
|
+
sample = sample + jax.lax.stop_gradient(hard_sample - sample)
|
|
614
|
+
return sample
|
|
615
|
+
return self._jax_unary_helper(expr, aux, not_op)
|
|
616
|
+
|
|
617
|
+
def _jax_and(self, expr, aux):
|
|
618
|
+
if not self.traced.cached_is_fluent(expr):
|
|
619
|
+
return super()._jax_and(expr, aux)
|
|
620
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
621
|
+
def and_op(x, y):
|
|
622
|
+
sample = jax.nn.relu(x + y - 1.)
|
|
623
|
+
if self.use_logic_ste:
|
|
624
|
+
hard_sample = jnp.asarray(jnp.logical_and(x > 0.5, y > 0.5), dtype=self.REAL)
|
|
625
|
+
sample = sample + jax.lax.stop_gradient(hard_sample - sample)
|
|
626
|
+
return sample
|
|
627
|
+
return self._jax_nary_helper(expr, aux, and_op)
|
|
628
|
+
|
|
629
|
+
def _jax_or(self, expr, aux):
|
|
630
|
+
if not self.traced.cached_is_fluent(expr):
|
|
631
|
+
return super()._jax_or(expr, aux)
|
|
632
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
633
|
+
def or_op(x, y):
|
|
634
|
+
sample = 1. - jax.nn.relu(1. - x - y)
|
|
635
|
+
if self.use_logic_ste:
|
|
636
|
+
hard_sample = jnp.asarray(jnp.logical_or(x > 0.5, y > 0.5), dtype=self.REAL)
|
|
637
|
+
sample = sample + jax.lax.stop_gradient(hard_sample - sample)
|
|
638
|
+
return sample
|
|
639
|
+
return self._jax_nary_helper(expr, aux, or_op)
|
|
640
|
+
|
|
641
|
+
def _jax_xor(self, expr, aux):
|
|
642
|
+
if not self.traced.cached_is_fluent(expr):
|
|
643
|
+
return super()._jax_xor(expr, aux)
|
|
644
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
645
|
+
def xor_op(x, y):
|
|
646
|
+
sample = jax.nn.relu(1. - jnp.abs(1. - x - y))
|
|
647
|
+
if self.use_logic_ste:
|
|
648
|
+
hard_sample = jnp.asarray(jnp.logical_xor(x > 0.5, y > 0.5), dtype=self.REAL)
|
|
649
|
+
sample = sample + jax.lax.stop_gradient(hard_sample - sample)
|
|
650
|
+
return sample
|
|
651
|
+
return self._jax_binary_helper(expr, aux, xor_op)
|
|
652
|
+
|
|
653
|
+
def _jax_implies(self, expr, aux):
|
|
654
|
+
if not self.traced.cached_is_fluent(expr):
|
|
655
|
+
return super()._jax_implies(expr, aux)
|
|
656
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
657
|
+
def implies_op(x, y):
|
|
658
|
+
sample = 1. - jax.nn.relu(x - y)
|
|
659
|
+
if self.use_logic_ste:
|
|
660
|
+
hard_sample = jnp.asarray(jnp.logical_or(x <= 0.5, y > 0.5), dtype=self.REAL)
|
|
661
|
+
sample = sample + jax.lax.stop_gradient(hard_sample - sample)
|
|
662
|
+
return sample
|
|
663
|
+
return self._jax_binary_helper(expr, aux, implies_op)
|
|
664
|
+
|
|
665
|
+
def _jax_equiv(self, expr, aux):
|
|
666
|
+
if not self.traced.cached_is_fluent(expr):
|
|
667
|
+
return super()._jax_equiv(expr, aux)
|
|
668
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
669
|
+
def equiv_op(x, y):
|
|
670
|
+
sample = jax.nn.relu(1. - jnp.abs(x - y))
|
|
671
|
+
if self.use_logic_ste:
|
|
672
|
+
hard_sample = jnp.logical_and(
|
|
673
|
+
jnp.logical_or(x <= 0.5, y > 0.5), jnp.logical_or(y <= 0.5, x > 0.5))
|
|
674
|
+
hard_sample = jnp.asarray(hard_sample, dtype=self.REAL)
|
|
675
|
+
sample = sample + jax.lax.stop_gradient(hard_sample - sample)
|
|
676
|
+
return sample
|
|
677
|
+
return self._jax_binary_helper(expr, aux, equiv_op)
|
|
678
|
+
|
|
679
|
+
def _jax_forall(self, expr, aux):
|
|
680
|
+
if not self.traced.cached_is_fluent(expr):
|
|
681
|
+
return super()._jax_forall(expr, aux)
|
|
682
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
683
|
+
def forall_op(x, axis):
|
|
684
|
+
sample = jax.nn.relu(jnp.sum(x - 1., axis=axis) + 1.)
|
|
685
|
+
if self.use_logic_ste:
|
|
686
|
+
hard_sample = jnp.all(x, axis=axis)
|
|
687
|
+
sample = sample + jax.lax.stop_gradient(hard_sample - sample)
|
|
688
|
+
return sample
|
|
689
|
+
return self._jax_aggregation_helper(expr, aux, forall_op)
|
|
690
|
+
|
|
691
|
+
def _jax_exists(self, expr, aux):
|
|
692
|
+
if not self.traced.cached_is_fluent(expr):
|
|
693
|
+
return super()._jax_exists(expr, aux)
|
|
694
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
695
|
+
def exists_op(x, axis):
|
|
696
|
+
sample = 1. - jax.nn.relu(jnp.sum(-x, axis=axis) + 1.)
|
|
697
|
+
if self.use_logic_ste:
|
|
698
|
+
hard_sample = jnp.any(x, axis=axis)
|
|
699
|
+
sample = sample + jax.lax.stop_gradient(hard_sample - sample)
|
|
700
|
+
return sample
|
|
701
|
+
return self._jax_aggregation_helper(expr, aux, exists_op)
|
|
702
|
+
|
|
703
|
+
|
|
704
|
+
# ===============================================================================
|
|
705
|
+
# function relaxations
|
|
706
|
+
# ===============================================================================
|
|
707
|
+
|
|
708
|
+
class SafeSqrt(JaxRDDLCompilerWithGrad):
|
|
709
|
+
'''Sqrt operation without negative underflow.'''
|
|
710
|
+
|
|
711
|
+
def __init__(self, *args, sqrt_eps: float=1e-14, **kwargs) -> None:
|
|
712
|
+
super(SafeSqrt, self).__init__(*args, **kwargs)
|
|
713
|
+
self.sqrt_eps = float(sqrt_eps)
|
|
714
|
+
|
|
715
|
+
def get_kwargs(self) -> Dict[str, Any]:
|
|
716
|
+
kwargs = super().get_kwargs()
|
|
717
|
+
kwargs['sqrt_eps'] = self.sqrt_eps
|
|
718
|
+
return kwargs
|
|
719
|
+
|
|
720
|
+
def _jax_sqrt(self, expr, aux):
|
|
721
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
722
|
+
def safe_sqrt_op(x):
|
|
723
|
+
return jnp.sqrt(x + self.sqrt_eps)
|
|
724
|
+
return self._jax_unary_helper(expr, aux, safe_sqrt_op, at_least_int=True)
|
|
725
|
+
|
|
726
|
+
|
|
727
|
+
class SoftFloor(JaxRDDLCompilerWithGrad):
|
|
728
|
+
'''Floor and ceil operations approximated using soft operations.'''
|
|
729
|
+
|
|
730
|
+
def __init__(self, *args, floor_weight: float=10.,
|
|
731
|
+
use_floor_ste: bool=True, **kwargs) -> None:
|
|
732
|
+
super(SoftFloor, self).__init__(*args, **kwargs)
|
|
733
|
+
self.floor_weight = float(floor_weight)
|
|
734
|
+
self.use_floor_ste = use_floor_ste
|
|
735
|
+
|
|
736
|
+
def get_kwargs(self) -> Dict[str, Any]:
|
|
737
|
+
kwargs = super().get_kwargs()
|
|
738
|
+
kwargs['floor_weight'] = self.floor_weight
|
|
739
|
+
kwargs['use_floor_ste'] = self.use_floor_ste
|
|
740
|
+
return kwargs
|
|
741
|
+
|
|
742
|
+
@staticmethod
|
|
743
|
+
def soft_floor(x: jnp.ndarray, w: float) -> jnp.ndarray:
|
|
744
|
+
s = x - jnp.floor(x)
|
|
745
|
+
return jnp.floor(x) + 0.5 * (
|
|
746
|
+
1. + stable_tanh(w * (s - 1.) / 2.) / stable_tanh(w / 4.))
|
|
747
|
+
|
|
748
|
+
def _jax_floor(self, expr, aux):
|
|
749
|
+
if not self.traced.cached_is_fluent(expr):
|
|
750
|
+
return super()._jax_floor(expr, aux)
|
|
751
|
+
id_ = expr.id
|
|
752
|
+
aux['params'][id_] = self.floor_weight
|
|
753
|
+
aux['overriden'][id_] = __class__.__name__
|
|
754
|
+
def floor_op(x, params):
|
|
755
|
+
sample = self.soft_floor(x, params[id_])
|
|
756
|
+
if self.use_floor_ste:
|
|
757
|
+
sample = sample + jax.lax.stop_gradient(jnp.floor(x) - sample)
|
|
758
|
+
return sample, params
|
|
759
|
+
return self._jax_unary_helper_with_param(expr, aux, floor_op)
|
|
760
|
+
|
|
761
|
+
def _jax_ceil(self, expr, aux):
|
|
762
|
+
if not self.traced.cached_is_fluent(expr):
|
|
763
|
+
return super()._jax_ceil(expr, aux)
|
|
764
|
+
id_ = expr.id
|
|
765
|
+
aux['params'][id_] = self.floor_weight
|
|
766
|
+
aux['overriden'][id_] = __class__.__name__
|
|
767
|
+
def ceil_op(x, params):
|
|
768
|
+
sample = -self.soft_floor(-x, params[id_])
|
|
769
|
+
if self.use_floor_ste:
|
|
770
|
+
sample = sample + jax.lax.stop_gradient(jnp.ceil(x) - sample)
|
|
771
|
+
return sample, params
|
|
772
|
+
return self._jax_unary_helper_with_param(expr, aux, ceil_op)
|
|
773
|
+
|
|
774
|
+
def _jax_div(self, expr, aux):
|
|
775
|
+
if not self.traced.cached_is_fluent(expr):
|
|
776
|
+
return super()._jax_div(expr, aux)
|
|
777
|
+
id_ = expr.id
|
|
778
|
+
aux['params'][id_] = self.floor_weight
|
|
779
|
+
aux['overriden'][id_] = __class__.__name__
|
|
780
|
+
def div_op(x, y, params):
|
|
781
|
+
sample = self.soft_floor(x / y, params[id_])
|
|
782
|
+
if self.use_floor_ste:
|
|
783
|
+
sample = sample + jax.lax.stop_gradient(jnp.floor_divide(x, y) - sample)
|
|
784
|
+
return sample, params
|
|
785
|
+
return self._jax_binary_helper_with_param(expr, aux, div_op)
|
|
786
|
+
|
|
787
|
+
def _jax_mod(self, expr, aux):
|
|
788
|
+
if not self.traced.cached_is_fluent(expr):
|
|
789
|
+
return super()._jax_mod(expr, aux)
|
|
790
|
+
id_ = expr.id
|
|
791
|
+
aux['params'][id_] = self.floor_weight
|
|
792
|
+
aux['overriden'][id_] = __class__.__name__
|
|
793
|
+
def mod_op(x, y, params):
|
|
794
|
+
div = self.soft_floor(x / y, params[id_])
|
|
795
|
+
if self.use_floor_ste:
|
|
796
|
+
div = div + jax.lax.stop_gradient(jnp.floor_divide(x, y) - div)
|
|
797
|
+
sample = x - y * div
|
|
798
|
+
return sample, params
|
|
799
|
+
return self._jax_binary_helper_with_param(expr, aux, mod_op)
|
|
800
|
+
|
|
801
|
+
|
|
802
|
+
class SoftRound(JaxRDDLCompilerWithGrad):
|
|
803
|
+
'''Round operations approximated using soft operations.'''
|
|
804
|
+
|
|
805
|
+
def __init__(self, *args, round_weight: float=10.,
|
|
806
|
+
use_round_ste: bool=True, **kwargs) -> None:
|
|
807
|
+
super(SoftRound, self).__init__(*args, **kwargs)
|
|
808
|
+
self.round_weight = float(round_weight)
|
|
809
|
+
self.use_round_ste = use_round_ste
|
|
810
|
+
|
|
811
|
+
def get_kwargs(self) -> Dict[str, Any]:
|
|
812
|
+
kwargs = super().get_kwargs()
|
|
813
|
+
kwargs['round_weight'] = self.round_weight
|
|
814
|
+
kwargs['use_round_ste'] = self.use_round_ste
|
|
815
|
+
return kwargs
|
|
816
|
+
|
|
817
|
+
def _jax_round(self, expr, aux):
|
|
818
|
+
if not self.traced.cached_is_fluent(expr):
|
|
819
|
+
return super()._jax_round(expr, aux)
|
|
820
|
+
id_ = expr.id
|
|
821
|
+
aux['params'][id_] = self.round_weight
|
|
822
|
+
aux['overriden'][id_] = __class__.__name__
|
|
823
|
+
def round_op(x, params):
|
|
184
824
|
param = params[id_]
|
|
185
825
|
m = jnp.floor(x) + 0.5
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
826
|
+
sample = m + 0.5 * stable_tanh(param * (x - m)) / stable_tanh(param / 2.)
|
|
827
|
+
if self.use_round_ste:
|
|
828
|
+
sample = sample + jax.lax.stop_gradient(jnp.round(x) - sample)
|
|
829
|
+
return sample, params
|
|
830
|
+
return self._jax_unary_helper_with_param(expr, aux, round_op)
|
|
189
831
|
|
|
190
|
-
def __str__(self) -> str:
|
|
191
|
-
return f'SoftFloor and SoftRound with weight {self.weight}'
|
|
192
832
|
|
|
833
|
+
# ===============================================================================
|
|
834
|
+
# control flow relaxations
|
|
835
|
+
# ===============================================================================
|
|
193
836
|
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
# - abstract class
|
|
197
|
-
# - standard complement
|
|
198
|
-
#
|
|
199
|
-
# ===========================================================================
|
|
837
|
+
class LinearIfElse(JaxRDDLCompilerWithGrad):
|
|
838
|
+
'''Approximate if else statement as a linear combination.'''
|
|
200
839
|
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
@abstractmethod
|
|
205
|
-
def __call__(self, id, init_params):
|
|
206
|
-
pass
|
|
840
|
+
def __init__(self, *args, use_if_else_ste: bool=True, **kwargs) -> None:
|
|
841
|
+
super(LinearIfElse, self).__init__(*args, **kwargs)
|
|
842
|
+
self.use_if_else_ste = use_if_else_ste
|
|
207
843
|
|
|
844
|
+
def get_kwargs(self) -> Dict[str, Any]:
|
|
845
|
+
kwargs = super().get_kwargs()
|
|
846
|
+
kwargs['use_if_else_ste'] = self.use_if_else_ste
|
|
847
|
+
return kwargs
|
|
208
848
|
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
# https://www.sciencedirect.com/science/article/abs/pii/016501149190171L
|
|
213
|
-
@staticmethod
|
|
214
|
-
def _jax_wrapped_calc_not_approx(x, params):
|
|
215
|
-
return 1.0 - x, params
|
|
216
|
-
|
|
217
|
-
def __call__(self, id, init_params):
|
|
218
|
-
return self._jax_wrapped_calc_not_approx
|
|
219
|
-
|
|
220
|
-
def __str__(self) -> str:
|
|
221
|
-
return 'Standard complement'
|
|
222
|
-
|
|
849
|
+
def _jax_if(self, expr, aux):
|
|
850
|
+
JaxRDDLCompilerWithGrad._check_num_args(expr, 3)
|
|
851
|
+
pred, if_true, if_false = expr.args
|
|
223
852
|
|
|
224
|
-
#
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
#
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
853
|
+
# if predicate is non-fluent, always use the exact operation
|
|
854
|
+
if not self.traced.cached_is_fluent(pred):
|
|
855
|
+
return super()._jax_if(expr, aux)
|
|
856
|
+
|
|
857
|
+
# recursively compile arguments
|
|
858
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
859
|
+
jax_pred = self._jax(pred, aux)
|
|
860
|
+
jax_true = self._jax(if_true, aux)
|
|
861
|
+
jax_false = self._jax(if_false, aux)
|
|
862
|
+
|
|
863
|
+
def _jax_wrapped_if_then_else_linear(fls, nfls, params, key):
|
|
864
|
+
sample_pred, key, err1, params = jax_pred(fls, nfls, params, key)
|
|
865
|
+
sample_true, key, err2, params = jax_true(fls, nfls, params, key)
|
|
866
|
+
sample_false, key, err3, params = jax_false(fls, nfls, params, key)
|
|
867
|
+
if self.use_if_else_ste:
|
|
868
|
+
hard_pred = (sample_pred > 0.5).astype(sample_pred.dtype)
|
|
869
|
+
sample_pred = sample_pred + jax.lax.stop_gradient(hard_pred - sample_pred)
|
|
870
|
+
sample = sample_pred * sample_true + (1 - sample_pred) * sample_false
|
|
871
|
+
err = err1 | err2 | err3
|
|
872
|
+
return sample, key, err, params
|
|
873
|
+
return _jax_wrapped_if_then_else_linear
|
|
874
|
+
|
|
875
|
+
|
|
876
|
+
class SoftmaxSwitch(JaxRDDLCompilerWithGrad):
|
|
877
|
+
'''Softmax switch control flow using a probabilistic interpretation.'''
|
|
878
|
+
|
|
879
|
+
def __init__(self, *args, switch_weight: float=10., **kwargs) -> None:
|
|
880
|
+
super(SoftmaxSwitch, self).__init__(*args, **kwargs)
|
|
881
|
+
self.switch_weight = float(switch_weight)
|
|
882
|
+
|
|
883
|
+
def get_kwargs(self) -> Dict[str, Any]:
|
|
884
|
+
kwargs = super().get_kwargs()
|
|
885
|
+
kwargs['switch_weight'] = self.switch_weight
|
|
886
|
+
return kwargs
|
|
887
|
+
|
|
888
|
+
def _jax_switch(self, expr, aux):
|
|
889
|
+
|
|
890
|
+
# if predicate is non-fluent, always use the exact operation
|
|
891
|
+
# case conditions are currently only literals so they are non-fluent
|
|
892
|
+
pred = expr.args[0]
|
|
893
|
+
if not self.traced.cached_is_fluent(pred):
|
|
894
|
+
return super()._jax_switch(expr, aux)
|
|
895
|
+
|
|
896
|
+
id_ = expr.id
|
|
897
|
+
aux['params'][id_] = self.switch_weight
|
|
898
|
+
aux['overriden'][id_] = __class__.__name__
|
|
899
|
+
|
|
900
|
+
# recursively compile predicate
|
|
901
|
+
jax_pred = self._jax(pred, aux)
|
|
902
|
+
|
|
903
|
+
# recursively compile cases
|
|
904
|
+
cases, default = self.traced.cached_sim_info(expr)
|
|
905
|
+
jax_default = None if default is None else self._jax(default, aux)
|
|
906
|
+
jax_cases = [
|
|
907
|
+
(jax_default if _case is None else self._jax(_case, aux))
|
|
908
|
+
for _case in cases
|
|
909
|
+
]
|
|
910
|
+
|
|
911
|
+
def _jax_wrapped_switch_softmax(fls, nfls, params, key):
|
|
912
|
+
|
|
913
|
+
# sample predicate
|
|
914
|
+
sample_pred, key, err, params = jax_pred(fls, nfls, params, key)
|
|
915
|
+
|
|
916
|
+
# sample cases
|
|
917
|
+
sample_cases = []
|
|
918
|
+
for jax_case in jax_cases:
|
|
919
|
+
sample, key, err_case, params = jax_case(fls, nfls, params, key)
|
|
920
|
+
sample_cases.append(sample)
|
|
921
|
+
err = err | err_case
|
|
922
|
+
sample_cases = jnp.asarray(sample_cases)
|
|
923
|
+
sample_cases = jnp.asarray(sample_cases, dtype=self._fix_dtype(sample_cases))
|
|
924
|
+
|
|
925
|
+
# replace integer indexing with softmax
|
|
926
|
+
sample_pred = jnp.broadcast_to(
|
|
927
|
+
sample_pred[jnp.newaxis, ...], shape=jnp.shape(sample_cases))
|
|
928
|
+
literals = enumerate_literals(jnp.shape(sample_cases), axis=0)
|
|
929
|
+
proximity = -jnp.square(sample_pred - literals)
|
|
930
|
+
logits = params[id_] * proximity
|
|
931
|
+
sample = stable_softmax_weight_sum(logits, sample_cases, axis=0)
|
|
932
|
+
return sample, key, err, params
|
|
933
|
+
return _jax_wrapped_switch_softmax
|
|
934
|
+
|
|
935
|
+
|
|
936
|
+
# ===============================================================================
|
|
937
|
+
# distribution relaxations - Geometric
|
|
938
|
+
# ===============================================================================
|
|
939
|
+
|
|
940
|
+
class ReparameterizedGeometric(JaxRDDLCompilerWithGrad):
|
|
941
|
+
|
|
942
|
+
def __init__(self, *args,
|
|
943
|
+
geometric_floor_weight: float=10.,
|
|
944
|
+
geometric_eps: float=1e-14, **kwargs) -> None:
|
|
945
|
+
super(ReparameterizedGeometric, self).__init__(*args, **kwargs)
|
|
946
|
+
self.geometric_floor_weight = float(geometric_floor_weight)
|
|
947
|
+
self.geometric_eps = float(geometric_eps)
|
|
948
|
+
|
|
949
|
+
def get_kwargs(self) -> Dict[str, Any]:
|
|
950
|
+
kwargs = super().get_kwargs()
|
|
951
|
+
kwargs['geometric_floor_weight'] = self.geometric_floor_weight
|
|
952
|
+
kwargs['geometric_eps'] = self.geometric_eps
|
|
953
|
+
return kwargs
|
|
954
|
+
|
|
955
|
+
def _jax_geometric(self, expr, aux):
|
|
956
|
+
ERR = JaxRDDLCompilerWithGrad.ERROR_CODES['INVALID_PARAM_GEOMETRIC']
|
|
957
|
+
JaxRDDLCompilerWithGrad._check_num_args(expr, 1)
|
|
958
|
+
arg_prob, = expr.args
|
|
959
|
+
|
|
960
|
+
# if prob is non-fluent, always use the exact operation
|
|
961
|
+
if not self.traced.cached_is_fluent(arg_prob):
|
|
962
|
+
return super()._jax_geometric(expr, aux)
|
|
963
|
+
|
|
964
|
+
id_ = expr.id
|
|
965
|
+
aux['params'][id_] = (self.geometric_floor_weight, self.geometric_eps)
|
|
966
|
+
aux['overriden'][id_] = __class__.__name__
|
|
967
|
+
|
|
968
|
+
jax_prob = self._jax(arg_prob, aux)
|
|
969
|
+
|
|
970
|
+
def _jax_wrapped_distribution_geometric_reparam(fls, nfls, params, key):
|
|
971
|
+
w, eps = params[id_]
|
|
972
|
+
prob, key, err, params = jax_prob(fls, nfls, params, key)
|
|
973
|
+
key, subkey = random.split(key)
|
|
974
|
+
U = random.uniform(key=subkey, shape=jnp.shape(prob), dtype=self.REAL)
|
|
975
|
+
sample = 1. + SoftFloor.soft_floor(jnp.log1p(-U) / jnp.log1p(-prob + eps), w=w)
|
|
976
|
+
out_of_bounds = jnp.logical_not(jnp.all(jnp.logical_and(prob >= 0, prob <= 1)))
|
|
977
|
+
err = err | (out_of_bounds * ERR)
|
|
978
|
+
return sample, key, err, params
|
|
979
|
+
return _jax_wrapped_distribution_geometric_reparam
|
|
234
980
|
|
|
235
|
-
class TNorm(metaclass=ABCMeta):
|
|
236
|
-
'''Base class for fuzzy differentiable t-norms.'''
|
|
237
|
-
|
|
238
|
-
@abstractmethod
|
|
239
|
-
def norm(self, id, init_params):
|
|
240
|
-
'''Elementwise t-norm of x and y.'''
|
|
241
|
-
pass
|
|
242
|
-
|
|
243
|
-
@abstractmethod
|
|
244
|
-
def norms(self, id, init_params):
|
|
245
|
-
'''T-norm computed for tensor x along axis.'''
|
|
246
|
-
pass
|
|
247
|
-
|
|
248
981
|
|
|
249
|
-
class
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
982
|
+
class DeterminizedGeometric(JaxRDDLCompilerWithGrad):
|
|
983
|
+
|
|
984
|
+
def __init__(self, *args, **kwargs) -> None:
|
|
985
|
+
super(DeterminizedGeometric, self).__init__(*args, **kwargs)
|
|
986
|
+
|
|
987
|
+
def get_kwargs(self) -> Dict[str, Any]:
|
|
988
|
+
return super().get_kwargs()
|
|
989
|
+
|
|
990
|
+
def _jax_geometric(self, expr, aux):
|
|
991
|
+
ERR = JaxRDDLCompilerWithGrad.ERROR_CODES['INVALID_PARAM_GEOMETRIC']
|
|
992
|
+
JaxRDDLCompilerWithGrad._check_num_args(expr, 1)
|
|
993
|
+
arg_prob, = expr.args
|
|
255
994
|
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
995
|
+
# if prob is non-fluent, always use the exact operation
|
|
996
|
+
if not self.traced.cached_is_fluent(arg_prob):
|
|
997
|
+
return super()._jax_geometric(expr, aux)
|
|
998
|
+
|
|
999
|
+
aux['overriden'][expr.id] = __class__.__name__
|
|
1000
|
+
|
|
1001
|
+
jax_prob = self._jax(arg_prob, aux)
|
|
262
1002
|
|
|
263
|
-
|
|
264
|
-
|
|
1003
|
+
def _jax_wrapped_distribution_geometric_determinized(fls, nfls, params, key):
|
|
1004
|
+
prob, key, err, params = jax_prob(fls, nfls, params, key)
|
|
1005
|
+
sample = 1. / prob
|
|
1006
|
+
out_of_bounds = jnp.logical_not(jnp.all(jnp.logical_and(prob >= 0, prob <= 1)))
|
|
1007
|
+
err = err | (out_of_bounds * ERR)
|
|
1008
|
+
return sample, key, err, params
|
|
1009
|
+
return _jax_wrapped_distribution_geometric_determinized
|
|
265
1010
|
|
|
266
|
-
def __str__(self) -> str:
|
|
267
|
-
return 'Product t-norm'
|
|
268
|
-
|
|
269
1011
|
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
@staticmethod
|
|
281
|
-
def _jax_wrapped_calc_forall_approx(x, axis, params):
|
|
282
|
-
return jnp.min(x, axis=axis), params
|
|
1012
|
+
# ===============================================================================
|
|
1013
|
+
# distribution relaxations - Bernoulli
|
|
1014
|
+
# ===============================================================================
|
|
1015
|
+
|
|
1016
|
+
class ReparameterizedSigmoidBernoulli(JaxRDDLCompilerWithGrad):
|
|
1017
|
+
|
|
1018
|
+
def __init__(self, *args, bernoulli_sigmoid_weight: float=10., **kwargs) -> None:
|
|
1019
|
+
super(ReparameterizedSigmoidBernoulli, self).__init__(*args, **kwargs)
|
|
1020
|
+
self.bernoulli_sigmoid_weight = float(bernoulli_sigmoid_weight)
|
|
283
1021
|
|
|
284
|
-
def
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
1022
|
+
def get_kwargs(self) -> Dict[str, Any]:
|
|
1023
|
+
kwargs = super().get_kwargs()
|
|
1024
|
+
kwargs['bernoulli_sigmoid_weight'] = self.bernoulli_sigmoid_weight
|
|
1025
|
+
return kwargs
|
|
1026
|
+
|
|
1027
|
+
def _jax_bernoulli(self, expr, aux):
|
|
1028
|
+
ERR = JaxRDDLCompilerWithGrad.ERROR_CODES['INVALID_PARAM_BERNOULLI']
|
|
1029
|
+
JaxRDDLCompilerWithGrad._check_num_args(expr, 1)
|
|
1030
|
+
arg_prob, = expr.args
|
|
1031
|
+
|
|
1032
|
+
# if prob is non-fluent, always use the exact operation
|
|
1033
|
+
if not self.traced.cached_is_fluent(arg_prob):
|
|
1034
|
+
return super()._jax_bernoulli(expr, aux)
|
|
1035
|
+
|
|
1036
|
+
id_ = expr.id
|
|
1037
|
+
aux['params'][id_] = self.bernoulli_sigmoid_weight
|
|
1038
|
+
aux['overriden'][id_] = __class__.__name__
|
|
290
1039
|
|
|
291
|
-
|
|
292
|
-
'''Lukasiewicz t-norm given by the expression (x, y) -> max(x + y - 1, 0).'''
|
|
293
|
-
|
|
294
|
-
@staticmethod
|
|
295
|
-
def _jax_wrapped_calc_and_approx(x, y, params):
|
|
296
|
-
land = jax.nn.relu(x + y - 1.0)
|
|
297
|
-
return land, params
|
|
1040
|
+
jax_prob = self._jax(arg_prob, aux)
|
|
298
1041
|
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
1042
|
+
def _jax_wrapped_distribution_bernoulli_reparam(fls, nfls, params, key):
|
|
1043
|
+
prob, key, err, params = jax_prob(fls, nfls, params, key)
|
|
1044
|
+
key, subkey = random.split(key)
|
|
1045
|
+
U = random.uniform(key=subkey, shape=jnp.shape(prob), dtype=self.REAL)
|
|
1046
|
+
sample = stable_sigmoid(params[id_] * (prob - U))
|
|
1047
|
+
out_of_bounds = jnp.logical_not(jnp.all(jnp.logical_and(prob >= 0, prob <= 1)))
|
|
1048
|
+
err = err | (out_of_bounds * ERR)
|
|
1049
|
+
return sample, key, err, params
|
|
1050
|
+
return _jax_wrapped_distribution_bernoulli_reparam
|
|
1051
|
+
|
|
1052
|
+
|
|
1053
|
+
+class GumbelSoftmaxBernoulli(JaxRDDLCompilerWithGrad):
+
+    def __init__(self, *args,
+                 bernoulli_softmax_weight: float=10.,
+                 bernoulli_eps: float=1e-14, **kwargs) -> None:
+        super(GumbelSoftmaxBernoulli, self).__init__(*args, **kwargs)
+        self.bernoulli_softmax_weight = float(bernoulli_softmax_weight)
+        self.bernoulli_eps = float(bernoulli_eps)
+
+    def get_kwargs(self) -> Dict[str, Any]:
+        kwargs = super().get_kwargs()
+        kwargs['bernoulli_softmax_weight'] = self.bernoulli_softmax_weight
+        kwargs['bernoulli_eps'] = self.bernoulli_eps
+        return kwargs
+
+    def _jax_bernoulli(self, expr, aux):
+        ERR = JaxRDDLCompilerWithGrad.ERROR_CODES['INVALID_PARAM_BERNOULLI']
+        JaxRDDLCompilerWithGrad._check_num_args(expr, 1)
+        arg_prob, = expr.args
+
+        # if prob is non-fluent, always use the exact operation
+        if not self.traced.cached_is_fluent(arg_prob):
+            return super()._jax_bernoulli(expr, aux)
 
+        id_ = expr.id
+        aux['params'][id_] = (self.bernoulli_softmax_weight, self.bernoulli_eps)
+        aux['overriden'][id_] = __class__.__name__
+
+        jax_prob = self._jax(arg_prob, aux)
+
+        def _jax_wrapped_distribution_bernoulli_gumbel_softmax(fls, nfls, params, key):
+            w, eps = params[id_]
+            prob, key, err, params = jax_prob(fls, nfls, params, key)
+            probs = jnp.stack([1. - prob, prob], axis=-1)
+            key, subkey = random.split(key)
+            g = random.gumbel(key=subkey, shape=jnp.shape(probs), dtype=self.REAL)
+            sample = SoftmaxArgmax.soft_argmax(g + jnp.log(probs + eps), w=w, axes=-1)
+            out_of_bounds = jnp.logical_not(jnp.all(jnp.logical_and(prob >= 0, prob <= 1)))
+            err = err | (out_of_bounds * ERR)
+            return sample, key, err, params
+        return _jax_wrapped_distribution_bernoulli_gumbel_softmax
 
-    def __str__(self) -> str:
-        return 'Lukasiewicz t-norm'
 
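For reference, a self-contained sketch of the two-outcome Gumbel-softmax trick used above, under the assumption that SoftmaxArgmax.soft_argmax behaves like a temperature-controlled softmax-weighted index (the name gumbel_softmax_bernoulli is illustrative, not part of the package):

import jax
import jax.numpy as jnp

def gumbel_softmax_bernoulli(key, prob, weight=10.0, eps=1e-14):
    # perturb the log-probabilities of the outcomes {0, 1} with Gumbel noise,
    # then take a soft argmax: a softmax-weighted average of the outcome indices
    probs = jnp.stack([1.0 - prob, prob], axis=-1)
    g = jax.random.gumbel(key, shape=probs.shape)
    weights = jax.nn.softmax(weight * (g + jnp.log(probs + eps)), axis=-1)
    return jnp.sum(weights * jnp.arange(2), axis=-1)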
+class DeterminizedBernoulli(JaxRDDLCompilerWithGrad):
 
-    (x, y) -> max(1 - ((1 - x)^p + (1 - y)^p)^(1/p)).'''
-
-    def __init__(self, p: float=2.0) -> None:
-        self.p = float(p)
-
-    def norm(self, id, init_params):
-        id_ = str(id)
-        init_params[id_] = self.p
-        def _jax_wrapped_calc_and_approx(x, y, params):
-            base = jax.nn.relu(1.0 - jnp.stack([x, y], axis=0))
-            arg = jnp.linalg.norm(base, ord=params[id_], axis=0)
-            land = jax.nn.relu(1.0 - arg)
-            return land, params
-        return _jax_wrapped_calc_and_approx
-
-    def norms(self, id, init_params):
-        id_ = str(id)
-        init_params[id_] = self.p
-        def _jax_wrapped_calc_forall_approx(x, axis, params):
-            arg = jax.nn.relu(1.0 - x)
-            for ax in sorted(axis, reverse=True):
-                arg = jnp.linalg.norm(arg, ord=params[id_], axis=ax)
-            forall = jax.nn.relu(1.0 - arg)
-            return forall, params
-        return _jax_wrapped_calc_forall_approx
-
-    def __str__(self) -> str:
-        return f'Yager({self.p}) t-norm'
-
+    def __init__(self, *args, **kwargs) -> None:
+        super(DeterminizedBernoulli, self).__init__(*args, **kwargs)
 
-    # - abstract sampler
-    # - Gumbel-softmax sampler
-    # - determinization
-    #
-    # ===========================================================================
+    def get_kwargs(self) -> Dict[str, Any]:
+        return super().get_kwargs()
 
-    @abstractmethod
-    def binomial(self, id, init_params, logic):
-        pass
-
-    @abstractmethod
-    def negative_binomial(self, id, init_params, logic):
-        pass
-
-    @abstractmethod
-    def geometric(self, id, init_params, logic):
-        pass
-
-    @abstractmethod
-    def bernoulli(self, id, init_params, logic):
-        pass
-
-    def __str__(self) -> str:
-        return 'RandomSampling'
+    def _jax_bernoulli(self, expr, aux):
+        ERR = JaxRDDLCompilerWithGrad.ERROR_CODES['INVALID_PARAM_BERNOULLI']
+        JaxRDDLCompilerWithGrad._check_num_args(expr, 1)
+        arg_prob, = expr.args
+
+        # if prob is non-fluent, always use the exact operation
+        if not self.traced.cached_is_fluent(arg_prob):
+            return super()._jax_bernoulli(expr, aux)
+
+        aux['overriden'][expr.id] = __class__.__name__
 
+        jax_prob = self._jax(arg_prob, aux)
+
+        def _jax_wrapped_distribution_bernoulli_determinized(fls, nfls, params, key):
+            prob, key, err, params = jax_prob(fls, nfls, params, key)
+            sample = prob
+            out_of_bounds = jnp.logical_not(jnp.all(jnp.logical_and(prob >= 0, prob <= 1)))
+            err = err | (out_of_bounds * ERR)
+            return sample, key, err, params
+        return _jax_wrapped_distribution_bernoulli_determinized
+
+
+# ===============================================================================
+# distribution relaxations - Discrete
+# ===============================================================================
+
+# https://arxiv.org/pdf/1611.01144
+class GumbelSoftmaxDiscrete(JaxRDDLCompilerWithGrad):
+
+    def __init__(self, *args,
+                 discrete_softmax_weight: float=10.,
+                 discrete_eps: float=1e-14, **kwargs) -> None:
+        super(GumbelSoftmaxDiscrete, self).__init__(*args, **kwargs)
+        self.discrete_softmax_weight = float(discrete_softmax_weight)
+        self.discrete_eps = float(discrete_eps)
+
+    def get_kwargs(self) -> Dict[str, Any]:
+        kwargs = super().get_kwargs()
+        kwargs['discrete_softmax_weight'] = self.discrete_softmax_weight
+        kwargs['discrete_eps'] = self.discrete_eps
+        return kwargs
+
+    def _jax_discrete(self, expr, aux, unnorm):
+
+        # if all probabilities are non-fluent, then always sample exact
+        ordered_args = self.traced.cached_sim_info(expr)
+        if not any(self.traced.cached_is_fluent(arg) for arg in ordered_args):
+            return super()._jax_discrete(expr, aux)
+
+        id_ = expr.id
+        aux['params'][id_] = (self.discrete_softmax_weight, self.discrete_eps)
+        aux['overriden'][id_] = __class__.__name__
 
-                 binomial_max_bins: int=100,
-                 bernoulli_gumbel_softmax: bool=False) -> None:
-        '''Creates a new instance of soft random sampling.
-
-        :param poisson_max_bins: maximum bins to use for Poisson distribution relaxation
-        :param poisson_min_cdf: minimum cdf value of Poisson within truncated region
-        in order to use Poisson relaxation
-        :param poisson_exp_sampling: whether to use Poisson process sampling method
-        instead of truncated Gumbel-Softmax
-        :param binomial_max_bins: maximum bins to use for Binomial distribution relaxation
-        :param bernoulli_gumbel_softmax: whether to use Gumbel-Softmax to approximate
-        Bernoulli samples, or the standard uniform reparameterization instead
-        '''
-        self.poisson_bins = poisson_max_bins
-        self.poisson_min_cdf = poisson_min_cdf
-        self.poisson_exp_method = poisson_exp_sampling
-        self.binomial_bins = binomial_max_bins
-        self.bernoulli_gumbel_softmax = bernoulli_gumbel_softmax
-
-    # https://arxiv.org/pdf/1611.01144
-    def discrete(self, id, init_params, logic):
-        argmax_approx = logic.argmax(id, init_params)
-        def _jax_wrapped_calc_discrete_gumbel_softmax(key, prob, params):
-            Gumbel01 = random.gumbel(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
-            sample = Gumbel01 + jnp.log(prob + logic.eps)
-            return argmax_approx(sample, axis=-1, params=params)
-        return _jax_wrapped_calc_discrete_gumbel_softmax
-
-    def _poisson_gumbel_softmax(self, id, init_params, logic):
-        argmax_approx = logic.argmax(id, init_params)
-        def _jax_wrapped_calc_poisson_gumbel_softmax(key, rate, params):
-            ks = jnp.arange(self.poisson_bins)[(jnp.newaxis,) * jnp.ndim(rate) + (...,)]
-            rate = rate[..., jnp.newaxis]
-            log_prob = ks * jnp.log(rate + logic.eps) - rate - scipy.special.gammaln(ks + 1)
-            Gumbel01 = random.gumbel(key=key, shape=jnp.shape(log_prob), dtype=logic.REAL)
-            sample = Gumbel01 + log_prob
-            return argmax_approx(sample, axis=-1, params=params)
-        return _jax_wrapped_calc_poisson_gumbel_softmax
-
-    # https://arxiv.org/abs/2405.14473
-    def _poisson_exponential(self, id, init_params, logic):
-        less_approx = logic.less(id, init_params)
-        def _jax_wrapped_calc_poisson_exponential(key, rate, params):
-            Exp1 = random.exponential(
-                key=key, shape=(self.poisson_bins,) + jnp.shape(rate), dtype=logic.REAL)
-            delta_t = Exp1 / rate[jnp.newaxis, ...]
-            times = jnp.cumsum(delta_t, axis=0)
-            indicator, params = less_approx(times, 1.0, params)
-            sample = jnp.sum(indicator, axis=0)
-            return sample, params
-        return _jax_wrapped_calc_poisson_exponential
-
-    # normal approximation to Poisson: Poisson(rate) -> Normal(rate, rate)
-    def _poisson_normal_approx(self, logic):
-        def _jax_wrapped_calc_poisson_normal_approx(key, rate, params):
-            normal = random.normal(key=key, shape=jnp.shape(rate), dtype=logic.REAL)
-            sample = rate + jnp.sqrt(rate) * normal
-            return sample, params
-        return _jax_wrapped_calc_poisson_normal_approx
-
-    def poisson(self, id, init_params, logic):
-        if self.poisson_exp_method:
-            _jax_wrapped_calc_poisson_diff = self._poisson_exponential(
-                id, init_params, logic)
-        else:
-            _jax_wrapped_calc_poisson_diff = self._poisson_gumbel_softmax(
-                id, init_params, logic)
-        _jax_wrapped_calc_poisson_normal = self._poisson_normal_approx(logic)
-
-        # for small rate use the Poisson process or gumbel-softmax reparameterization
-        # for large rate use the normal approximation
-        def _jax_wrapped_calc_poisson_approx(key, rate, params):
-            if self.poisson_bins > 0:
-                cuml_prob = scipy.stats.poisson.cdf(self.poisson_bins, rate)
-                small_rate = jax.lax.stop_gradient(cuml_prob >= self.poisson_min_cdf)
-                small_sample, params = _jax_wrapped_calc_poisson_diff(key, rate, params)
-                large_sample, params = _jax_wrapped_calc_poisson_normal(key, rate, params)
-                sample = jnp.where(small_rate, small_sample, large_sample)
-                return sample, params
-            else:
-                return _jax_wrapped_calc_poisson_normal(key, rate, params)
-        return _jax_wrapped_calc_poisson_approx
-
-    # normal approximation to Binomial: Bin(n, p) -> Normal(np, np(1-p))
-    def _binomial_normal_approx(self, logic):
-        def _jax_wrapped_calc_binomial_normal_approx(key, trials, prob, params):
-            normal = random.normal(key=key, shape=jnp.shape(trials), dtype=logic.REAL)
-            mean = trials * prob
-            std = jnp.sqrt(trials * prob * (1.0 - prob))
-            sample = mean + std * normal
-            return sample, params
-        return _jax_wrapped_calc_binomial_normal_approx
-
-    def _binomial_gumbel_softmax(self, id, init_params, logic):
-        argmax_approx = logic.argmax(id, init_params)
-        def _jax_wrapped_calc_binomial_gumbel_softmax(key, trials, prob, params):
-            ks = jnp.arange(self.binomial_bins)[(jnp.newaxis,) * jnp.ndim(trials) + (...,)]
-            trials = trials[..., jnp.newaxis]
-            prob = prob[..., jnp.newaxis]
-            in_support = ks <= trials
-            ks = jnp.minimum(ks, trials)
-            log_prob = ((scipy.special.gammaln(trials + 1) -
-                         scipy.special.gammaln(ks + 1) -
-                         scipy.special.gammaln(trials - ks + 1)) +
-                        ks * jnp.log(prob + logic.eps) +
-                        (trials - ks) * jnp.log1p(-prob + logic.eps))
-            log_prob = jnp.where(in_support, log_prob, jnp.log(logic.eps))
-            Gumbel01 = random.gumbel(key=key, shape=jnp.shape(log_prob), dtype=logic.REAL)
-            sample = Gumbel01 + log_prob
-            return argmax_approx(sample, axis=-1, params=params)
-        return _jax_wrapped_calc_binomial_gumbel_softmax
-
-    def binomial(self, id, init_params, logic):
-        _jax_wrapped_calc_binomial_normal = self._binomial_normal_approx(logic)
-        _jax_wrapped_calc_binomial_gs = self._binomial_gumbel_softmax(id, init_params, logic)
-
-        # for small trials use the Bernoulli relaxation
-        # for large trials use the normal approximation
-        def _jax_wrapped_calc_binomial_approx(key, trials, prob, params):
-            small_trials = jax.lax.stop_gradient(trials < self.binomial_bins)
-            small_sample, params = _jax_wrapped_calc_binomial_gs(key, trials, prob, params)
-            large_sample, params = _jax_wrapped_calc_binomial_normal(key, trials, prob, params)
-            sample = jnp.where(small_trials, small_sample, large_sample)
-            return sample, params
-        return _jax_wrapped_calc_binomial_approx
-
-    # https://en.wikipedia.org/wiki/Negative_binomial_distribution#Gamma%E2%80%93Poisson_mixture
-    def negative_binomial(self, id, init_params, logic):
-        poisson_approx = self.poisson(id, init_params, logic)
-        def _jax_wrapped_calc_negative_binomial_approx(key, trials, prob, params):
+        jax_probs = [self._jax(arg, aux) for arg in ordered_args]
+        prob_fn = self._jax_discrete_prob(jax_probs, unnorm)
+
+        def _jax_wrapped_distribution_discrete_gumbel_softmax(fls, nfls, params, key):
+            w, eps = params[id_]
+            prob, key, err, params = prob_fn(fls, nfls, params, key)
             key, subkey = random.split(key)
-            return less_approx(U, prob, params)
-        return _jax_wrapped_calc_bernoulli_uniform
-
-    def _bernoulli_gumbel_softmax(self, id, init_params, logic):
-        discrete_approx = self.discrete(id, init_params, logic)
-        def _jax_wrapped_calc_bernoulli_gumbel_softmax(key, prob, params):
-            prob = jnp.stack([1.0 - prob, prob], axis=-1)
-            return discrete_approx(key, prob, params)
-        return _jax_wrapped_calc_bernoulli_gumbel_softmax
-
-    def bernoulli(self, id, init_params, logic):
-        if self.bernoulli_gumbel_softmax:
-            return self._bernoulli_gumbel_softmax(id, init_params, logic)
-        else:
-            return self._bernoulli_uniform(id, init_params, logic)
-
-    def __str__(self) -> str:
-        return 'SoftRandomSampling'
-
+            g = random.gumbel(key=subkey, shape=jnp.shape(prob), dtype=self.REAL)
+            sample = SoftmaxArgmax.soft_argmax(g + jnp.log(prob + eps), w=w, axes=-1)
+            err = JaxRDDLCompilerWithGrad._jax_update_discrete_oob_error(err, prob)
+            return sample, key, err, params
+        return _jax_wrapped_distribution_discrete_gumbel_softmax
+
+    def _jax_discrete_pvar(self, expr, aux, unnorm):
+        JaxRDDLCompilerWithGrad._check_num_args(expr, 2)
+        _, args = expr.args
+        arg, = args
+
+        # if all probabilities are non-fluent, then always sample exact
+        if not self.traced.cached_is_fluent(arg):
+            return super()._jax_discrete_pvar(expr, aux)
+
+        id_ = expr.id
+        aux['params'][id_] = (self.discrete_softmax_weight, self.discrete_eps)
+        aux['overriden'][id_] = __class__.__name__
+
+        jax_probs = self._jax(arg, aux)
+        prob_fn = self._jax_discrete_pvar_prob(jax_probs, unnorm)
 
+        def _jax_wrapped_distribution_discrete_pvar_gumbel_softmax(fls, nfls, params, key):
+            w, eps = params[id_]
+            prob, key, err, params = prob_fn(fls, nfls, params, key)
+            key, subkey = random.split(key)
+            g = random.gumbel(key=subkey, shape=jnp.shape(prob), dtype=self.REAL)
+            sample = SoftmaxArgmax.soft_argmax(g + jnp.log(prob + eps), w=w, axes=-1)
+            err = JaxRDDLCompilerWithGrad._jax_update_discrete_oob_error(err, prob)
+            return sample, key, err, params
+        return _jax_wrapped_distribution_discrete_pvar_gumbel_softmax
 
-    @staticmethod
-    def _jax_wrapped_calc_discrete_determinized(key, prob, params):
-        literals = enumerate_literals(jnp.shape(prob), axis=-1)
-        sample = jnp.sum(literals * prob, axis=-1)
-        return sample, params
-
-    def discrete(self, id, init_params, logic):
-        return self._jax_wrapped_calc_discrete_determinized
-
-    @staticmethod
-    def _jax_wrapped_calc_poisson_determinized(key, rate, params):
-        return rate, params
 
-        return self._jax_wrapped_calc_poisson_determinized
-
-    @staticmethod
-    def _jax_wrapped_calc_binomial_determinized(key, trials, prob, params):
-        sample = trials * prob
-        return sample, params
-
-    def binomial(self, id, init_params, logic):
-        return self._jax_wrapped_calc_binomial_determinized
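The same construction generalizes to Discrete variables with more than two outcomes; a hedged standalone sketch (gumbel_softmax_categorical is an illustrative name, not part of the package):

import jax
import jax.numpy as jnp

def gumbel_softmax_categorical(key, probs, weight=10.0, eps=1e-14):
    # probs holds the category probabilities along the last axis; a hard sample
    # is argmax(g + log p), the relaxation replaces argmax by a weighted index
    g = jax.random.gumbel(key, shape=probs.shape)
    weights = jax.nn.softmax(weight * (g + jnp.log(probs + eps)), axis=-1)
    literals = jnp.arange(probs.shape[-1])
    return jnp.sum(weights * literals, axis=-1)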
+class DeterminizedDiscrete(JaxRDDLCompilerWithGrad):
 
-        sample = trials * ((1.0 / prob) - 1.0)
-        return sample, params
+    def __init__(self, *args, **kwargs) -> None:
+        super(DeterminizedDiscrete, self).__init__(*args, **kwargs)
 
-    def
-        return
-
-    @staticmethod
-    def _jax_wrapped_calc_geometric_determinized(key, prob, params):
-        sample = 1.0 / prob
-        return sample, params
-
-    def geometric(self, id, init_params, logic):
-        return self._jax_wrapped_calc_geometric_determinized
-
-    @staticmethod
-    def _jax_wrapped_calc_bernoulli_determinized(key, prob, params):
-        sample = prob
-        return sample, params
-
-    def bernoulli(self, id, init_params, logic):
-        return self._jax_wrapped_calc_bernoulli_determinized
+    def get_kwargs(self) -> Dict[str, Any]:
+        return super().get_kwargs()
 
-    def
+    def _jax_discrete(self, expr, aux, unnorm):
+
+        # if all probabilities are non-fluent, then always sample exact
+        ordered_args = self.traced.cached_sim_info(expr)
+        if not any(self.traced.cached_is_fluent(arg) for arg in ordered_args):
+            return super()._jax_discrete(expr, aux)
+
+        aux['overriden'][expr.id] = __class__.__name__
+
+        jax_probs = [self._jax(arg, aux) for arg in ordered_args]
+        prob_fn = self._jax_discrete_prob(jax_probs, unnorm)
+
+        def _jax_wrapped_distribution_discrete_determinized(fls, nfls, params, key):
+            prob, key, err, params = prob_fn(fls, nfls, params, key)
+            literals = enumerate_literals(jnp.shape(prob), axis=-1)
+            sample = jnp.sum(literals * prob, axis=-1)
+            err = JaxRDDLCompilerWithGrad._jax_update_discrete_oob_error(err, prob)
+            return sample, key, err, params
+        return _jax_wrapped_distribution_discrete_determinized
+
+    def _jax_discrete_pvar(self, expr, aux, unnorm):
+        JaxRDDLCompilerWithGrad._check_num_args(expr, 2)
+        _, args = expr.args
+        arg, = args
+
+        # if all probabilities are non-fluent, then always sample exact
+        if not self.traced.cached_is_fluent(arg):
+            return super()._jax_discrete_pvar(expr, aux)
+
+        aux['overriden'][expr.id] = __class__.__name__
+
+        jax_probs = self._jax(arg, aux)
+        prob_fn = self._jax_discrete_pvar_prob(jax_probs, unnorm)
+
+        def _jax_wrapped_distribution_discrete_pvar_determinized(fls, nfls, params, key):
+            prob, key, err, params = prob_fn(fls, nfls, params, key)
+            literals = enumerate_literals(jnp.shape(prob), axis=-1)
+            sample = jnp.sum(literals * prob, axis=-1)
+            err = JaxRDDLCompilerWithGrad._jax_update_discrete_oob_error(err, prob)
+            return sample, key, err, params
+        return _jax_wrapped_distribution_discrete_pvar_determinized
+
+
+# ===============================================================================
+# distribution relaxations - Binomial
+# ===============================================================================
+
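The Determinized* compilers above simply substitute each sample by its mean, the crudest but cheapest differentiable surrogate. In isolation (helper names are illustrative, not part of the package):

import jax.numpy as jnp

def determinized_bernoulli(prob):
    return prob                                   # E[Bernoulli(p)] = p

def determinized_discrete(probs):
    literals = jnp.arange(probs.shape[-1])
    return jnp.sum(literals * probs, axis=-1)     # E[Discrete(p)] = sum_i i * p_i

def determinized_binomial(trials, prob):
    return trials * prob                          # E[Binomial(n, p)] = n * p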
+class GumbelSoftmaxBinomial(JaxRDDLCompilerWithGrad):
+
+    def __init__(self, *args,
+                 binomial_nbins: int=100,
+                 binomial_softmax_weight: float=10.,
+                 binomial_eps: float=1e-14, **kwargs) -> None:
+        super(GumbelSoftmaxBinomial, self).__init__(*args, **kwargs)
+        self.binomial_nbins = binomial_nbins
+        self.binomial_softmax_weight = float(binomial_softmax_weight)
+        self.binomial_eps = float(binomial_eps)
+
+    def get_kwargs(self) -> Dict[str, Any]:
+        kwargs = super().get_kwargs()
+        kwargs['binomial_nbins'] = self.binomial_nbins
+        kwargs['binomial_softmax_weight'] = self.binomial_softmax_weight
+        kwargs['binomial_eps'] = self.binomial_eps
+        return kwargs
 
-    #
+    # normal approximation to Binomial: Bin(n, p) -> Normal(np, np(1-p))
+    @staticmethod
+    def normal_approx_to_binomial(key: random.PRNGKey,
+                                  trials: jnp.ndarray, prob: jnp.ndarray) -> jnp.ndarray:
+        normal = random.normal(key=key, shape=jnp.shape(trials), dtype=prob.dtype)
+        mean = trials * prob
+        std = jnp.sqrt(trials * jnp.clip(prob * (1.0 - prob), 0.0, 1.0))
+        return mean + std * normal
+
+    def gumbel_softmax_approx_to_binomial(self, key: random.PRNGKey,
+                                          trials: jnp.ndarray, prob: jnp.ndarray,
+                                          w: float, eps: float):
+        ks = jnp.arange(self.binomial_nbins)[(jnp.newaxis,) * jnp.ndim(trials) + (...,)]
+        trials = trials[..., jnp.newaxis]
+        prob = prob[..., jnp.newaxis]
+        in_support = ks <= trials
+        ks = jnp.minimum(ks, trials)
+        log_prob = ((scipy.special.gammaln(trials + 1) -
+                     scipy.special.gammaln(ks + 1) -
+                     scipy.special.gammaln(trials - ks + 1)) +
+                    ks * jnp.log(prob + eps) +
+                    (trials - ks) * jnp.log1p(-prob + eps))
+        log_prob = jnp.where(in_support, log_prob, jnp.log(eps))
+        g = random.gumbel(key=key, shape=jnp.shape(log_prob), dtype=prob.dtype)
+        return SoftmaxArgmax.soft_argmax(g + log_prob, w=w, axes=-1)
+
+    def _jax_binomial(self, expr, aux):
+        ERR = JaxRDDLCompilerWithGrad.ERROR_CODES['INVALID_PARAM_BINOMIAL']
+        JaxRDDLCompilerWithGrad._check_num_args(expr, 2)
+        arg_trials, arg_prob = expr.args
+
+        # if prob is non-fluent, always use the exact operation
+        if (not self.traced.cached_is_fluent(arg_trials) and
+                not self.traced.cached_is_fluent(arg_prob)):
+            return super()._jax_binomial(expr, aux)
+
+        id_ = expr.id
+        aux['params'][id_] = (self.binomial_softmax_weight, self.binomial_eps)
+        aux['overriden'][id_] = __class__.__name__
+
+        # recursively compile arguments
+        jax_trials = self._jax(arg_trials, aux)
+        jax_prob = self._jax(arg_prob, aux)
 
-    @abstractmethod
-    def switch(self, id, init_params):
-        pass
+        def _jax_wrapped_distribution_binomial_gumbel_softmax(fls, nfls, params, key):
+            trials, key, err2, params = jax_trials(fls, nfls, params, key)
+            prob, key, err1, params = jax_prob(fls, nfls, params, key)
+            key, subkey = random.split(key)
+            trials = jnp.asarray(trials, dtype=self.REAL)
+            prob = jnp.asarray(prob, dtype=self.REAL)
 
+            # use the gumbel-softmax trick for small population size
+            # use the normal approximation for large population size
+            sample = jnp.where(
+                jax.lax.stop_gradient(trials < self.binomial_nbins),
+                self.gumbel_softmax_approx_to_binomial(subkey, trials, prob, *params[id_]),
+                self.normal_approx_to_binomial(subkey, trials, prob)
+            )
 
+            out_of_bounds = jnp.logical_not(jnp.all(
+                jnp.logical_and(jnp.logical_and(prob >= 0, prob <= 1), trials >= 0)))
+            err = err1 | err2 | (out_of_bounds * ERR)
+            return sample, key, err, params
+        return _jax_wrapped_distribution_binomial_gumbel_softmax
 
-    def __init__(self, weight: float=10.0) -> None:
-        self.weight = float(weight)
-
-    @staticmethod
-    def _jax_wrapped_calc_if_then_else_soft(c, a, b, params):
-        sample = c * a + (1.0 - c) * b
-        return sample, params
-
-    def if_then_else(self, id, init_params):
-        return self._jax_wrapped_calc_if_then_else_soft
-
-    def switch(self, id, init_params):
-        id_ = str(id)
-        init_params[id_] = self.weight
-        def _jax_wrapped_calc_switch_soft(pred, cases, params):
-            literals = enumerate_literals(jnp.shape(cases), axis=0)
-            pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=jnp.shape(cases))
-            proximity = -jnp.square(pred - literals)
-            softcase = jax.nn.softmax(params[id_] * proximity, axis=0)
-            sample = jnp.sum(cases * softcase, axis=0)
-            return sample, params
-        return _jax_wrapped_calc_switch_soft
-
-    def __str__(self) -> str:
-        return f'Soft control flow with weight {self.weight}'
-
-
-# ===========================================================================
-# LOGIC
-# - exact logic
-# - fuzzy logic
-#
-# ===========================================================================
 
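Outside the compiler machinery, the Binomial relaxation above amounts to a Gumbel-softmax over a truncated support for small trial counts and the Normal(np, np(1-p)) approximation otherwise. A rough standalone sketch under those assumptions (relaxed_binomial is an illustrative name):

import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln

def relaxed_binomial(key, trials, prob, nbins=100, weight=10.0, eps=1e-14):
    key1, key2 = jax.random.split(key)
    trials = jnp.asarray(trials, dtype=jnp.float32)
    prob = jnp.asarray(prob, dtype=jnp.float32)
    # log pmf of Binomial(n, p) on the truncated support {0, ..., nbins - 1}
    ks = jnp.arange(nbins, dtype=jnp.float32)
    n, p = trials[..., None], prob[..., None]
    k = jnp.minimum(ks, n)
    log_pmf = (gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)
               + k * jnp.log(p + eps) + (n - k) * jnp.log1p(-p + eps))
    log_pmf = jnp.where(ks <= n, log_pmf, jnp.log(eps))
    g = jax.random.gumbel(key1, shape=log_pmf.shape)
    soft = jnp.sum(jax.nn.softmax(weight * (g + log_pmf), axis=-1) * ks, axis=-1)
    # normal approximation for trial counts beyond the truncated support
    z = jax.random.normal(key2, shape=jnp.shape(trials))
    normal = trials * prob + jnp.sqrt(trials * prob * (1.0 - prob)) * z
    return jnp.where(trials < nbins, soft, normal)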
+class DeterminizedBinomial(JaxRDDLCompilerWithGrad):
 
-    def __init__(self, use64bit: bool=False) -> None:
-        self.set_use64bit(use64bit)
-
-    def summarize_hyperparameters(self) -> str:
-        return (f'model relaxation:\n'
-                f' use_64_bit ={self.use64bit}')
-
-    def set_use64bit(self, use64bit: bool) -> None:
-        '''Toggles whether or not the JAX system will use 64 bit precision.'''
-        self.use64bit = use64bit
-        if use64bit:
-            self.REAL = jnp.float64
-            self.INT = jnp.int64
-            jax.config.update('jax_enable_x64', True)
-        else:
-            self.REAL = jnp.float32
-            self.INT = jnp.int32
-            jax.config.update('jax_enable_x64', False)
-
-    @staticmethod
-    def wrap_logic(func):
-        def exact_func(id, init_params):
-            return func
-        return exact_func
-
-    def get_operator_dicts(self) -> Dict[str, Union[Callable, Dict[str, Callable]]]:
-        '''Returns a dictionary of all operators in the current logic.'''
-        return {
-            'negative': self.wrap_logic(ExactLogic.exact_unary_function(jnp.negative)),
-            'arithmetic': {
-                '+': self.wrap_logic(ExactLogic.exact_binary_function(jnp.add)),
-                '-': self.wrap_logic(ExactLogic.exact_binary_function(jnp.subtract)),
-                '*': self.wrap_logic(ExactLogic.exact_binary_function(jnp.multiply)),
-                '/': self.wrap_logic(ExactLogic.exact_binary_function(jnp.divide))
-            },
-            'relational': {
-                '>=': self.greater_equal,
-                '<=': self.less_equal,
-                '<': self.less,
-                '>': self.greater,
-                '==': self.equal,
-                '~=': self.not_equal
-            },
-            'logical_not': self.logical_not,
-            'logical': {
-                '^': self.logical_and,
-                '&': self.logical_and,
-                '|': self.logical_or,
-                '~': self.xor,
-                '=>': self.implies,
-                '<=>': self.equiv
-            },
-            'aggregation': {
-                'sum': self.wrap_logic(ExactLogic.exact_aggregation(jnp.sum)),
-                'avg': self.wrap_logic(ExactLogic.exact_aggregation(jnp.mean)),
-                'prod': self.wrap_logic(ExactLogic.exact_aggregation(jnp.prod)),
-                'minimum': self.wrap_logic(ExactLogic.exact_aggregation(jnp.min)),
-                'maximum': self.wrap_logic(ExactLogic.exact_aggregation(jnp.max)),
-                'forall': self.forall,
-                'exists': self.exists,
-                'argmin': self.argmin,
-                'argmax': self.argmax
-            },
-            'unary': {
-                'abs': self.wrap_logic(ExactLogic.exact_unary_function(jnp.abs)),
-                'sgn': self.sgn,
-                'round': self.round,
-                'floor': self.floor,
-                'ceil': self.ceil,
-                'cos': self.wrap_logic(ExactLogic.exact_unary_function(jnp.cos)),
-                'sin': self.wrap_logic(ExactLogic.exact_unary_function(jnp.sin)),
-                'tan': self.wrap_logic(ExactLogic.exact_unary_function(jnp.tan)),
-                'acos': self.wrap_logic(ExactLogic.exact_unary_function(jnp.arccos)),
-                'asin': self.wrap_logic(ExactLogic.exact_unary_function(jnp.arcsin)),
-                'atan': self.wrap_logic(ExactLogic.exact_unary_function(jnp.arctan)),
-                'cosh': self.wrap_logic(ExactLogic.exact_unary_function(jnp.cosh)),
-                'sinh': self.wrap_logic(ExactLogic.exact_unary_function(jnp.sinh)),
-                'tanh': self.wrap_logic(ExactLogic.exact_unary_function(jnp.tanh)),
-                'exp': self.wrap_logic(ExactLogic.exact_unary_function(jnp.exp)),
-                'ln': self.wrap_logic(ExactLogic.exact_unary_function(jnp.log)),
-                'sqrt': self.sqrt,
-                'lngamma': self.wrap_logic(ExactLogic.exact_unary_function(scipy.special.gammaln)),
-                'gamma': self.wrap_logic(ExactLogic.exact_unary_function(scipy.special.gamma))
-            },
-            'binary': {
-                'div': self.div,
-                'mod': self.mod,
-                'fmod': self.mod,
-                'min': self.wrap_logic(ExactLogic.exact_binary_function(jnp.minimum)),
-                'max': self.wrap_logic(ExactLogic.exact_binary_function(jnp.maximum)),
-                'pow': self.wrap_logic(ExactLogic.exact_binary_function(jnp.power)),
-                'log': self.wrap_logic(ExactLogic.exact_binary_log),
-                'hypot': self.wrap_logic(ExactLogic.exact_binary_function(jnp.hypot)),
-            },
-            'control': {
-                'if': self.control_if,
-                'switch': self.control_switch
-            },
-            'sampling': {
-                'Bernoulli': self.bernoulli,
-                'Discrete': self.discrete,
-                'Poisson': self.poisson,
-                'Geometric': self.geometric,
-                'Binomial': self.binomial,
-                'NegativeBinomial': self.negative_binomial
-            }
-        }
-
-    # ===========================================================================
-    # logical operators
-    # ===========================================================================
-
-    @abstractmethod
-    def logical_and(self, id, init_params):
-        pass
-
-    @abstractmethod
-    def logical_not(self, id, init_params):
-        pass
-
-    @abstractmethod
-    def logical_or(self, id, init_params):
-        pass
-
-    @abstractmethod
-    def xor(self, id, init_params):
-        pass
-
-    @abstractmethod
-    def implies(self, id, init_params):
-        pass
-
-    @abstractmethod
-    def equiv(self, id, init_params):
-        pass
-
-    @abstractmethod
-    def forall(self, id, init_params):
-        pass
-
-    @abstractmethod
-    def exists(self, id, init_params):
-        pass
-
-    # ===========================================================================
-    # comparison operators
-    # ===========================================================================
-
-    @abstractmethod
-    def greater_equal(self, id, init_params):
-        pass
-
-    @abstractmethod
-    def greater(self, id, init_params):
-        pass
-
-    @abstractmethod
-    def less_equal(self, id, init_params):
-        pass
-
-    @abstractmethod
-    def less(self, id, init_params):
-        pass
-
-    @abstractmethod
-    def equal(self, id, init_params):
-        pass
-
-    @abstractmethod
-    def not_equal(self, id, init_params):
-        pass
-
-    # ===========================================================================
-    # special functions
-    # ===========================================================================
-
-    @abstractmethod
-    def sgn(self, id, init_params):
-        pass
-
-    @abstractmethod
-    def floor(self, id, init_params):
-        pass
-
-    @abstractmethod
-    def round(self, id, init_params):
-        pass
-
-    @abstractmethod
-    def ceil(self, id, init_params):
-        pass
-
-    @abstractmethod
-    def div(self, id, init_params):
-        pass
-
-    @abstractmethod
-    def mod(self, id, init_params):
-        pass
-
-    @abstractmethod
-    def sqrt(self, id, init_params):
-        pass
-
-    # ===========================================================================
-    # indexing
-    # ===========================================================================
-
-    @abstractmethod
-    def argmax(self, id, init_params):
-        pass
-
-    @abstractmethod
-    def argmin(self, id, init_params):
-        pass
-
-    # ===========================================================================
-    # control flow
-    # ===========================================================================
-
-    @abstractmethod
-    def control_if(self, id, init_params):
-        pass
-
-    @abstractmethod
-    def control_switch(self, id, init_params):
-        pass
-
-    # ===========================================================================
-    # random variables
-    # ===========================================================================
-
-    @abstractmethod
-    def discrete(self, id, init_params):
-        pass
-
-    @abstractmethod
-    def bernoulli(self, id, init_params):
-        pass
-
-    @abstractmethod
-    def poisson(self, id, init_params):
-        pass
-
-    @abstractmethod
-    def geometric(self, id, init_params):
-        pass
-
-    @abstractmethod
-    def binomial(self, id, init_params):
-        pass
+    def __init__(self, *args, **kwargs) -> None:
+        super(DeterminizedBinomial, self).__init__(*args, **kwargs)
 
-        pass
+    def get_kwargs(self) -> Dict[str, Any]:
+        return super().get_kwargs()
 
+    def _jax_binomial(self, expr, aux):
+        ERR = JaxRDDLCompilerWithGrad.ERROR_CODES['INVALID_PARAM_BINOMIAL']
+        JaxRDDLCompilerWithGrad._check_num_args(expr, 2)
+        arg_trials, arg_prob = expr.args
 
-        return _jax_wrapped_calc_binary_function_exact
-
-    @staticmethod
-    def exact_aggregation(op):
-        def _jax_wrapped_calc_aggregation_exact(x, axis, params):
-            return op(x, axis=axis), params
-        return _jax_wrapped_calc_aggregation_exact
-
-    # ===========================================================================
-    # logical operators
-    # ===========================================================================
-
-    def logical_and(self, id, init_params):
-        return self.exact_binary_function(jnp.logical_and)
-
-    def logical_not(self, id, init_params):
-        return self.exact_unary_function(jnp.logical_not)
-
-    def logical_or(self, id, init_params):
-        return self.exact_binary_function(jnp.logical_or)
-
-    def xor(self, id, init_params):
-        return self.exact_binary_function(jnp.logical_xor)
-
-    @staticmethod
-    def _jax_wrapped_calc_implies_exact(x, y, params):
-        return jnp.logical_or(jnp.logical_not(x), y), params
-
-    def implies(self, id, init_params):
-        return self._jax_wrapped_calc_implies_exact
-
-    def equiv(self, id, init_params):
-        return self.exact_binary_function(jnp.equal)
-
-    def forall(self, id, init_params):
-        return self.exact_aggregation(jnp.all)
-
-    def exists(self, id, init_params):
-        return self.exact_aggregation(jnp.any)
-
-    # ===========================================================================
-    # comparison operators
-    # ===========================================================================
-
-    def greater_equal(self, id, init_params):
-        return self.exact_binary_function(jnp.greater_equal)
-
-    def greater(self, id, init_params):
-        return self.exact_binary_function(jnp.greater)
-
-    def less_equal(self, id, init_params):
-        return self.exact_binary_function(jnp.less_equal)
-
-    def less(self, id, init_params):
-        return self.exact_binary_function(jnp.less)
-
-    def equal(self, id, init_params):
-        return self.exact_binary_function(jnp.equal)
-
-    def not_equal(self, id, init_params):
-        return self.exact_binary_function(jnp.not_equal)
-
-    # ===========================================================================
-    # special functions
-    # ===========================================================================
-
-    @staticmethod
-    def exact_binary_log(x, y, params):
-        return jnp.log(x) / jnp.log(y), params
-
-    def sgn(self, id, init_params):
-        return self.exact_unary_function(jnp.sign)
-
-    def floor(self, id, init_params):
-        return self.exact_unary_function(jnp.floor)
-
-    def round(self, id, init_params):
-        return self.exact_unary_function(jnp.round)
-
-    def ceil(self, id, init_params):
-        return self.exact_unary_function(jnp.ceil)
-
-    def div(self, id, init_params):
-        return self.exact_binary_function(jnp.floor_divide)
-
-    def mod(self, id, init_params):
-        return self.exact_binary_function(jnp.mod)
-
-    def sqrt(self, id, init_params):
-        return self.exact_unary_function(jnp.sqrt)
-
-    # ===========================================================================
-    # indexing
-    # ===========================================================================
-
-    def argmax(self, id, init_params):
-        return self.exact_aggregation(jnp.argmax)
-
-    def argmin(self, id, init_params):
-        return self.exact_aggregation(jnp.argmin)
-
-    # ===========================================================================
-    # control flow
-    # ===========================================================================
-
-    @staticmethod
-    def _jax_wrapped_calc_if_then_else_exact(c, a, b, params):
-        return jnp.where(c > 0.5, a, b), params
-
-    def control_if(self, id, init_params):
-        return self._jax_wrapped_calc_if_then_else_exact
-
-    def control_switch(self, id, init_params):
-        def _jax_wrapped_calc_switch_exact(pred, cases, params):
-            pred = jnp.asarray(pred[jnp.newaxis, ...], dtype=self.INT)
-            sample = jnp.take_along_axis(cases, pred, axis=0)
-            assert sample.shape[0] == 1
-            return sample[0, ...], params
-        return _jax_wrapped_calc_switch_exact
-
-    # ===========================================================================
-    # random variables
-    # ===========================================================================
-
-    @staticmethod
-    def _jax_wrapped_calc_discrete_exact(key, prob, params):
-        sample = random.categorical(key=key, logits=jnp.log(prob), axis=-1)
-        return sample, params
-
-    def discrete(self, id, init_params):
-        return self._jax_wrapped_calc_discrete_exact
-
-    @staticmethod
-    def _jax_wrapped_calc_bernoulli_exact(key, prob, params):
-        return random.bernoulli(key, prob), params
-
-    def bernoulli(self, id, init_params):
-        return self._jax_wrapped_calc_bernoulli_exact
-
-    def poisson(self, id, init_params):
-        def _jax_wrapped_calc_poisson_exact(key, rate, params):
-            sample = random.poisson(key=key, lam=rate, dtype=self.INT)
-            return sample, params
-        return _jax_wrapped_calc_poisson_exact
-
-    def geometric(self, id, init_params):
-        def _jax_wrapped_calc_geometric_exact(key, prob, params):
-            sample = random.geometric(key=key, p=prob, dtype=self.INT)
-            return sample, params
-        return _jax_wrapped_calc_geometric_exact
-
-    def binomial(self, id, init_params):
-        def _jax_wrapped_calc_binomial_exact(key, trials, prob, params):
+        # if prob is non-fluent, always use the exact operation
+        if (not self.traced.cached_is_fluent(arg_trials) and
+                not self.traced.cached_is_fluent(arg_prob)):
+            return super()._jax_binomial(expr, aux)
+
+        aux['overriden'][expr.id] = __class__.__name__
+
+        jax_trials = self._jax(arg_trials, aux)
+        jax_prob = self._jax(arg_prob, aux)
+
+        def _jax_wrapped_distribution_binomial_determinized(fls, nfls, params, key):
+            trials, key, err2, params = jax_trials(fls, nfls, params, key)
+            prob, key, err1, params = jax_prob(fls, nfls, params, key)
             trials = jnp.asarray(trials, dtype=self.REAL)
             prob = jnp.asarray(prob, dtype=self.REAL)
-            sample =
+            sample = trials * prob
+            out_of_bounds = jnp.logical_not(jnp.all(
+                jnp.logical_and(jnp.logical_and(prob >= 0, prob <= 1), trials >= 0)))
+            err = err1 | err2 | (out_of_bounds * ERR)
+            return sample, key, err, params
+        return _jax_wrapped_distribution_binomial_determinized
+
+
+# ===============================================================================
+# distribution relaxations - Poisson and NegativeBinomial
+# ===============================================================================
+
+class ExponentialPoisson(JaxRDDLCompilerWithGrad):
+
+    def __init__(self, *args,
+                 poisson_nbins: int=100,
+                 poisson_comparison_weight: float=10.,
+                 poisson_min_cdf: float=0.999, **kwargs) -> None:
+        super(ExponentialPoisson, self).__init__(*args, **kwargs)
+        self.poisson_nbins = poisson_nbins
+        self.poisson_comparison_weight = float(poisson_comparison_weight)
+        self.poisson_min_cdf = float(poisson_min_cdf)
+
+    def get_kwargs(self) -> Dict[str, Any]:
+        kwargs = super().get_kwargs()
+        kwargs['poisson_nbins'] = self.poisson_nbins
+        kwargs['poisson_comparison_weight'] = self.poisson_comparison_weight
+        kwargs['poisson_min_cdf'] = self.poisson_min_cdf
+        return kwargs
+
+    def exponential_approx_to_poisson(self, key: random.PRNGKey,
+                                      rate: jnp.ndarray, w: float) -> jnp.ndarray:
+        exp = random.exponential(
+            key=key, shape=(self.poisson_nbins,) + jnp.shape(rate), dtype=rate.dtype)
+        delta_t = exp / rate[jnp.newaxis, ...]
+        times = jnp.cumsum(delta_t, axis=0)
+        indicator = stable_sigmoid(w * (1. - times))
+        return jnp.sum(indicator, axis=0)
+
+    def branched_approx_to_poisson(self, key: random.PRNGKey,
+                                   rate: jnp.ndarray, w: float, min_cdf: float) -> jnp.ndarray:
+        cuml_prob = scipy.stats.poisson.cdf(self.poisson_nbins, rate)
+        z = random.normal(key=key, shape=jnp.shape(rate), dtype=rate.dtype)
+        return jnp.where(
+            jax.lax.stop_gradient(cuml_prob >= min_cdf),
+            self.exponential_approx_to_poisson(key, rate, w),
+            rate + jnp.sqrt(rate) * z
+        )
+
+    def _jax_poisson(self, expr, aux):
+        ERR = JaxRDDLCompilerWithGrad.ERROR_CODES['INVALID_PARAM_POISSON']
+        JaxRDDLCompilerWithGrad._check_num_args(expr, 1)
+        arg_rate, = expr.args
+
+        # if rate is non-fluent, always use the exact operation
+        if not self.traced.cached_is_fluent(arg_rate):
+            return super()._jax_poisson(expr, aux)
+
+        id_ = expr.id
+        aux['params'][id_] = (self.poisson_comparison_weight, self.poisson_min_cdf)
+        aux['overriden'][id_] = __class__.__name__
+
+        jax_rate = self._jax(arg_rate, aux)
+
+        # use the exponential/Poisson process trick for small rate
+        # use the normal approximation for large rate
+        def _jax_wrapped_distribution_poisson_exponential(fls, nfls, params, key):
+            rate, key, err, params = jax_rate(fls, nfls, params, key)
+            key, subkey = random.split(key)
+            sample = self.branched_approx_to_poisson(subkey, rate, *params[id_])
+            out_of_bounds = jnp.logical_not(jnp.all(rate >= 0))
+            err = err | (out_of_bounds * ERR)
+            return sample, key, err, params
+        return _jax_wrapped_distribution_poisson_exponential
+
+    def _jax_negative_binomial(self, expr, aux):
+        ERR = JaxRDDLCompilerWithGrad.ERROR_CODES['INVALID_PARAM_NEGATIVE_BINOMIAL']
+        JaxRDDLCompilerWithGrad._check_num_args(expr, 2)
+        arg_trials, arg_prob = expr.args
+
+        # if prob and trials is non-fluent, always use the exact operation
+        if (not self.traced.cached_is_fluent(arg_trials) and
+                not self.traced.cached_is_fluent(arg_prob)):
+            return super()._jax_negative_binomial(expr, aux)
+
+        id_ = expr.id
+        aux['params'][id_] = (self.poisson_comparison_weight, self.poisson_min_cdf)
+        aux['overriden'][id_] = __class__.__name__
+
+        jax_trials = self._jax(arg_trials, aux)
+        jax_prob = self._jax(arg_prob, aux)
+
+        # https://en.wikipedia.org/wiki/Negative_binomial_distribution#Gamma%E2%80%93Poisson_mixture
+        def _jax_wrapped_distribution_negative_binomial_exponential(fls, nfls, params, key):
+            trials, key, err2, params = jax_trials(fls, nfls, params, key)
+            prob, key, err1, params = jax_prob(fls, nfls, params, key)
+            key, subkey = random.split(key)
             trials = jnp.asarray(trials, dtype=self.REAL)
             prob = jnp.asarray(prob, dtype=self.REAL)
+            gamma = random.gamma(key=subkey, a=trials, dtype=self.REAL)
+            rate = ((1.0 - prob) / prob) * gamma
+            sample = self.branched_approx_to_poisson(subkey, rate, *params[id_])
+            out_of_bounds = jnp.logical_not(jnp.all(
+                jnp.logical_and(jnp.logical_and(prob >= 0, prob <= 1), trials > 0)))
+            err = err1 | err2 | (out_of_bounds * ERR)
+            return sample, key, err, params
+        return _jax_wrapped_distribution_negative_binomial_exponential
+
+
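The exponential-racing construction above relies on the fact that inter-arrival times of a Poisson process are Exponential(rate), so the count over a unit interval is the number of cumulative arrival times below one; relaxing that comparison with a sigmoid yields gradients with respect to the rate. A minimal sketch (relaxed_poisson is an illustrative name, not part of the package):

import jax
import jax.numpy as jnp

def relaxed_poisson(key, rate, nbins=100, weight=10.0):
    # cumulative sums of Exponential(rate) inter-arrival times; counting how many
    # fall below 1 recovers Poisson(rate), and the sigmoid softens that count
    exp = jax.random.exponential(key, shape=(nbins,) + jnp.shape(rate))
    times = jnp.cumsum(exp / rate, axis=0)
    return jnp.sum(jax.nn.sigmoid(weight * (1.0 - times)), axis=0)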
class GumbelSoftmaxPoisson(JaxRDDLCompilerWithGrad):
|
|
1474
|
+
|
|
1475
|
+
def __init__(self, *args,
|
|
1476
|
+
poisson_nbins: int=100,
|
|
1477
|
+
poisson_softmax_weight: float=10.,
|
|
1478
|
+
poisson_min_cdf: float=0.999,
|
|
1479
|
+
poisson_eps: float=1e-14, **kwargs) -> None:
|
|
1480
|
+
super(GumbelSoftmaxPoisson, self).__init__(*args, **kwargs)
|
|
1481
|
+
self.poisson_nbins = poisson_nbins
|
|
1482
|
+
self.poisson_softmax_weight = float(poisson_softmax_weight)
|
|
1483
|
+
self.poisson_min_cdf = float(poisson_min_cdf)
|
|
1484
|
+
self.poisson_eps = float(poisson_eps)
|
|
1485
|
+
|
|
1486
|
+
def get_kwargs(self) -> Dict[str, Any]:
|
|
1487
|
+
kwargs = super().get_kwargs()
|
|
1488
|
+
kwargs['poisson_nbins'] = self.poisson_nbins
|
|
1489
|
+
kwargs['poisson_softmax_weight'] = self.poisson_softmax_weight
|
|
1490
|
+
kwargs['poisson_min_cdf'] = self.poisson_min_cdf
|
|
1491
|
+
kwargs['poisson_eps'] = self.poisson_eps
|
|
1492
|
+
return kwargs
|
|
1493
|
+
|
|
1494
|
+
def gumbel_softmax_poisson(self, key: random.PRNGKey,
|
|
1495
|
+
rate: jnp.ndarray, w: float, eps: float) -> jnp.ndarray:
|
|
1496
|
+
ks = jnp.arange(self.poisson_nbins)[(jnp.newaxis,) * jnp.ndim(rate) + (...,)]
|
|
1497
|
+
rate = rate[..., jnp.newaxis]
|
|
1498
|
+
log_prob = ks * jnp.log(rate + eps) - rate - scipy.special.gammaln(ks + 1)
|
|
1499
|
+
g = random.gumbel(key=key, shape=jnp.shape(log_prob), dtype=rate.dtype)
|
|
1500
|
+
return SoftmaxArgmax.soft_argmax(g + log_prob, w=w, axes=-1)
|
|
1501
|
+
|
|
1502
|
+
def branched_approx_to_poisson(self, key: random.PRNGKey,
|
|
1503
|
+
rate: jnp.ndarray,
|
|
1504
|
+
w: float, min_cdf: float, eps: float) -> jnp.ndarray:
|
|
1505
|
+
cuml_prob = scipy.stats.poisson.cdf(self.poisson_nbins, rate)
|
|
1506
|
+
z = random.normal(key=key, shape=jnp.shape(rate), dtype=rate.dtype)
|
|
1507
|
+
return jnp.where(
|
|
1508
|
+
jax.lax.stop_gradient(cuml_prob >= min_cdf),
|
|
1509
|
+
self.gumbel_softmax_poisson(key, rate, w, eps),
|
|
1510
|
+
rate + jnp.sqrt(rate) * z
|
|
1511
|
+
)
|
|
1512
|
+
|
|
1513
|
+
def _jax_poisson(self, expr, aux):
|
|
1514
|
+
1514 +         ERR = JaxRDDLCompilerWithGrad.ERROR_CODES['INVALID_PARAM_POISSON']
1515 +         JaxRDDLCompilerWithGrad._check_num_args(expr, 1)
1516 +         arg_rate, = expr.args
1517 +
1518 +         # if rate is non-fluent, always use the exact operation
1519 +         if not self.traced.cached_is_fluent(arg_rate):
1520 +             return super()._jax_poisson(expr, aux)
1521 +
1522 +         id_ = expr.id
1523 +         aux['params'][id_] = (self.poisson_softmax_weight, self.poisson_min_cdf, self.poisson_eps)
1524 +         aux['overriden'][id_] = __class__.__name__
1117 1525
1118 -
1119 - class FuzzyLogic(Logic):
1120 -     '''A class representing fuzzy logic in JAX.'''
1121 -
1122 -     def __init__(self, tnorm: TNorm=ProductTNorm(),
1123 -                  complement: Complement=StandardComplement(),
1124 -                  comparison: Comparison=SigmoidComparison(),
1125 -                  sampling: RandomSampling=SoftRandomSampling(),
1126 -                  rounding: Rounding=SoftRounding(),
1127 -                  control: ControlFlow=SoftControlFlow(),
1128 -                  eps: float=1e-15,
1129 -                  use64bit: bool=False) -> None:
1130 -         '''Creates a new fuzzy logic in Jax.
1131 -
1132 -         :param tnorm: fuzzy operator for logical AND
1133 -         :param complement: fuzzy operator for logical NOT
1134 -         :param comparison: fuzzy operator for comparisons (>, >=, <, ==, ~=, ...)
1135 -         :param sampling: random sampling of non-reparameterizable distributions
1136 -         :param rounding: rounding floating values to integers
1137 -         :param control: if and switch control structures
1138 -         :param eps: small positive float to mitigate underflow
1139 -         :param use64bit: whether to perform arithmetic in 64 bit
1140 -         '''
1141 -         super().__init__(use64bit=use64bit)
1142 -         self.tnorm = tnorm
1143 -         self.complement = complement
1144 -         self.comparison = comparison
1145 -         self.sampling = sampling
1146 -         self.rounding = rounding
1147 -         self.control = control
1148 -         self.eps = eps
1149 -
1150 -     def __str__(self) -> str:
1151 -         return (f'model relaxation:\n'
1152 -                 f'    tnorm        ={str(self.tnorm)}\n'
1153 -                 f'    complement   ={str(self.complement)}\n'
1154 -                 f'    comparison   ={str(self.comparison)}\n'
1155 -                 f'    sampling     ={str(self.sampling)}\n'
1156 -                 f'    rounding     ={str(self.rounding)}\n'
1157 -                 f'    control      ={str(self.control)}\n'
1158 -                 f'    underflow_tol={self.eps}\n'
1159 -                 f'    use_64_bit   ={self.use64bit}\n')
1160 -
1161 -     def summarize_hyperparameters(self) -> str:
1162 -         return self.__str__()
1163 -
1164 -     # ===========================================================================
1165 -     # logical operators
1166 -     # ===========================================================================
1167 -
1168 -     def logical_and(self, id, init_params):
1169 -         return self.tnorm.norm(id, init_params)
1170 -
1171 -     def logical_not(self, id, init_params):
1172 -         return self.complement(id, init_params)
1173 -
1174 -     def logical_or(self, id, init_params):
1175 -         _not1 = self.complement(f'{id}_~1', init_params)
1176 -         _not2 = self.complement(f'{id}_~2', init_params)
1177 -         _and = self.tnorm.norm(f'{id}_^', init_params)
1178 -         _not = self.complement(f'{id}_~', init_params)
1179 -
1180 -         def _jax_wrapped_calc_or_approx(x, y, params):
1181 -             not_x, params = _not1(x, params)
1182 -             not_y, params = _not2(y, params)
1183 -             not_x_and_not_y, params = _and(not_x, not_y, params)
1184 -             return _not(not_x_and_not_y, params)
1185 -         return _jax_wrapped_calc_or_approx
1186 -
1187 -     def xor(self, id, init_params):
1188 -         _not = self.complement(f'{id}_~', init_params)
1189 -         _and1 = self.tnorm.norm(f'{id}_^1', init_params)
1190 -         _and2 = self.tnorm.norm(f'{id}_^2', init_params)
1191 -         _or = self.logical_or(f'{id}_|', init_params)
1192 -
1193 -         def _jax_wrapped_calc_xor_approx(x, y, params):
1194 -             x_and_y, params = _and1(x, y, params)
1195 -             not_x_and_y, params = _not(x_and_y, params)
1196 -             x_or_y, params = _or(x, y, params)
1197 -             return _and2(x_or_y, not_x_and_y, params)
1198 -         return _jax_wrapped_calc_xor_approx
1199 -
1200 -     def implies(self, id, init_params):
1201 -         _not = self.complement(f'{id}_~', init_params)
1202 -         _or = self.logical_or(f'{id}_|', init_params)
1203 -
1204 -         def _jax_wrapped_calc_implies_approx(x, y, params):
1205 -             not_x, params = _not(x, params)
1206 -             return _or(not_x, y, params)
1207 -         return _jax_wrapped_calc_implies_approx
1208 -
1209 -     def equiv(self, id, init_params):
1210 -         _implies1 = self.implies(f'{id}_=>1', init_params)
1211 -         _implies2 = self.implies(f'{id}_=>2', init_params)
1212 -         _and = self.tnorm.norm(f'{id}_^', init_params)
1213 -
1214 -         def _jax_wrapped_calc_equiv_approx(x, y, params):
1215 -             x_implies_y, params = _implies1(x, y, params)
1216 -             y_implies_x, params = _implies2(y, x, params)
1217 -             return _and(x_implies_y, y_implies_x, params)
1218 -         return _jax_wrapped_calc_equiv_approx
1219 -
1220 -     def forall(self, id, init_params):
1221 -         return self.tnorm.norms(id, init_params)
1222 -
1223 -     def exists(self, id, init_params):
1224 -         _not1 = self.complement(f'{id}_~1', init_params)
1225 -         _not2 = self.complement(f'{id}_~2', init_params)
1226 -         _forall = self.forall(f'{id}_forall', init_params)
1227 -
1228 -         def _jax_wrapped_calc_exists_approx(x, axis, params):
1229 -             not_x, params = _not1(x, params)
1230 -             forall_not_x, params = _forall(not_x, axis, params)
1231 -             return _not2(forall_not_x, params)
1232 -         return _jax_wrapped_calc_exists_approx
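The removed FuzzyLogic operators above build OR, XOR, implication, equivalence, and exists purely out of the configured t-norm (fuzzy AND) and complement (fuzzy NOT), for example OR(a, b) = NOT(AND(NOT a, NOT b)). A minimal standalone sketch of that De Morgan composition, assuming the product t-norm and standard complement; the helper names are illustrative and are not the package's TNorm/Complement classes:

import jax.numpy as jnp

# standalone stand-ins for the product t-norm and the standard complement
def fuzzy_and(a, b):
    # product t-norm: fuzzy AND of membership values in [0, 1]
    return a * b

def fuzzy_not(a):
    # standard complement: fuzzy NOT
    return 1.0 - a

def fuzzy_or(a, b):
    # De Morgan composition, mirroring the removed logical_or:
    # OR(a, b) = NOT(AND(NOT a, NOT b))
    return fuzzy_not(fuzzy_and(fuzzy_not(a), fuzzy_not(b)))

a = jnp.asarray([0.0, 0.2, 0.9, 1.0])
b = jnp.asarray([0.0, 0.7, 0.1, 1.0])
print(fuzzy_or(a, b))      # approximately [0., 0.76, 0.91, 1.]
print(a + b - a * b)       # identical: with the product t-norm, OR(a, b) = a + b - a*b

Deriving every remaining connective from just these two primitives is what lets the class swap in a different t-norm without touching the rest of the logic.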
1233 -
1234 -     # ===========================================================================
1235 -     # comparison operators
1236 -     # ===========================================================================
1237 -
1238 -     def greater_equal(self, id, init_params):
1239 -         return self.comparison.greater_equal(id, init_params)
1240 -
1241 -     def greater(self, id, init_params):
1242 -         return self.comparison.greater(id, init_params)
1243 -
1244 -     def less_equal(self, id, init_params):
1245 -         _greater_eq = self.greater_equal(id, init_params)
1246 -         def _jax_wrapped_calc_leq_approx(x, y, params):
1247 -             return _greater_eq(-x, -y, params)
1248 -         return _jax_wrapped_calc_leq_approx
1249 -
1250 -     def less(self, id, init_params):
1251 -         _greater = self.greater(id, init_params)
1252 -         def _jax_wrapped_calc_less_approx(x, y, params):
1253 -             return _greater(-x, -y, params)
1254 -         return _jax_wrapped_calc_less_approx
1255 -
1256 -     def equal(self, id, init_params):
1257 -         return self.comparison.equal(id, init_params)
1258 -
1259 -     def not_equal(self, id, init_params):
1260 -         _not = self.complement(f'{id}_~', init_params)
1261 -         _equal = self.comparison.equal(f'{id}_==', init_params)
1262 -         def _jax_wrapped_calc_neq_approx(x, y, params):
1263 -             equal, params = _equal(x, y, params)
1264 -             return _not(equal, params)
1265 -         return _jax_wrapped_calc_neq_approx
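The removed less_equal and less reuse the relaxed greater_equal and greater on negated arguments, so only one direction of each comparison has to be relaxed. A hedged sketch of that pattern with a sigmoid comparison; the sharpness weight w here is an assumed stand-in for what SigmoidComparison configures:

import jax
import jax.numpy as jnp

def soft_greater_equal(x, y, w=10.0):
    # sigmoid relaxation of x >= y; a larger w gives a sharper step
    return jax.nn.sigmoid(w * (x - y))

def soft_less_equal(x, y, w=10.0):
    # mirror of the removed less_equal: reuse >= on negated arguments,
    # since x <= y is equivalent to -x >= -y
    return soft_greater_equal(-x, -y, w)

x = jnp.asarray([-1.0, 0.0, 1.0])
print(soft_greater_equal(x, 0.0))    # approximately [0., 0.5, 1.]
print(soft_less_equal(x, 0.0))       # approximately [1., 0.5, 0.]
# unlike a hard comparison, the relaxation has a usable gradient
print(jax.grad(lambda v: soft_greater_equal(v, 0.0))(0.0))   # w / 4 = 2.5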
1266 -
1267 -     # ===========================================================================
1268 -     # special functions
1269 -     # ===========================================================================
1270 -
1271 -     def sgn(self, id, init_params):
1272 -         return self.comparison.sgn(id, init_params)
1273 -
1274 -     def floor(self, id, init_params):
1275 -         return self.rounding.floor(id, init_params)
1526 +         jax_rate = self._jax(arg_rate, aux)
1276 1527
1277-1304 -   (28 blank lines removed)
1305 -     # ===========================================================================
1306 -     # indexing
1307 -     # ===========================================================================
1308 -
1309 -     def argmax(self, id, init_params):
1310 -         return self.comparison.argmax(id, init_params)
1311 -
1312 -     def argmin(self, id, init_params):
1313 -         _argmax = self.argmax(id, init_params)
1314 -         def _jax_wrapped_calc_argmin_approx(x, axis, param):
1315 -             return _argmax(-x, axis, param)
1316 -         return _jax_wrapped_calc_argmin_approx
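Likewise, the removed argmin delegates to argmax on the negated input. A standalone sketch of a softmax-weighted argmax of the kind such a relaxation typically uses; the weight is again an assumed sharpness parameter, not the package's SoftmaxArgmax setting:

import jax
import jax.numpy as jnp

def soft_argmax(x, axis, w=100.0):
    # softmax-weighted index of a 1-D input: sum_i i * softmax(w * x)_i,
    # which approaches the true argmax as w grows
    weights = jax.nn.softmax(w * x, axis=axis)
    indices = jnp.arange(x.shape[axis], dtype=x.dtype)
    return jnp.sum(weights * indices, axis=axis)

def soft_argmin(x, axis, w=100.0):
    # mirror of the removed argmin: argmax of the negated input
    return soft_argmax(-x, axis, w)

values = jnp.asarray([2., 3., 5., 4.9, 4., 1., -1., -2.])
print(soft_argmax(values, 0))   # close to 2.0, the true argmax index
print(soft_argmin(values, 0))   # close to 7.0, the true argmin index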
1317 -
1318 -     # ===========================================================================
1319 -     # control flow
1320 -     # ===========================================================================
1321 -
1322 -     def control_if(self, id, init_params):
1323 -         return self.control.if_then_else(id, init_params)
1324 -
1325 -     def control_switch(self, id, init_params):
1326 -         return self.control.switch(id, init_params)
1327 -
1328 -     # ===========================================================================
1329 -     # random variables
1330 -     # ===========================================================================
1331 -
1332 -     def discrete(self, id, init_params):
1333 -         return self.sampling.discrete(id, init_params, self)
1334 -
1335 -     def bernoulli(self, id, init_params):
1336 -         return self.sampling.bernoulli(id, init_params, self)
1337 -
1338 -     def poisson(self, id, init_params):
1339 -         return self.sampling.poisson(id, init_params, self)
1340 -
1341 -     def geometric(self, id, init_params):
1342 -         return self.sampling.geometric(id, init_params, self)
1343 -
1344 -     def binomial(self, id, init_params):
1345 -         return self.sampling.binomial(id, init_params, self)
1346 -
1347 -     def negative_binomial(self, id, init_params):
1348 -         return self.sampling.negative_binomial(id, init_params, self)
1528 +         # use the gumbel-softmax and truncation trick for small rate
1529 +         # use the normal approximation for large rate
1530 +         def _jax_wrapped_distribution_poisson_gumbel_softmax(fls, nfls, params, key):
1531 +             rate, key, err, params = jax_rate(fls, nfls, params, key)
1532 +             key, subkey = random.split(key)
1533 +             sample = self.branched_approx_to_poisson(subkey, rate, *params[id_])
1534 +             out_of_bounds = jnp.logical_not(jnp.all(rate >= 0))
1535 +             err = err | (out_of_bounds * ERR)
1536 +             return sample, key, err, params
1537 +         return _jax_wrapped_distribution_poisson_gumbel_softmax
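The added sampler above delegates to a branched_approx_to_poisson helper that is defined outside this hunk; per the comments, it uses a Gumbel-softmax over a truncated support for small rates and a normal approximation for large rates. The sketch below is only an illustrative reconstruction of that idea under stated assumptions; the truncation bound, weight, and switching threshold are made up and are not the package's actual parameters:

import jax
import jax.numpy as jnp
from jax import random
from jax.scipy.special import gammaln

def gumbel_softmax_poisson(key, rate, weight=10.0, max_k=100, large_rate=50.0):
    key_g, key_n = random.split(key)
    ks = jnp.arange(max_k + 1, dtype=jnp.float32)

    # small-rate branch: Gumbel-softmax over the truncated support {0, ..., max_k}
    log_pmf = ks * jnp.log(rate) - rate - gammaln(ks + 1.0)
    gumbel = random.gumbel(key_g, shape=ks.shape)
    probs = jax.nn.softmax(weight * (log_pmf + gumbel))
    small_sample = jnp.sum(probs * ks)

    # large-rate branch: normal approximation Poisson(rate) ~ N(rate, rate)
    large_sample = rate + jnp.sqrt(rate) * random.normal(key_n)

    return jnp.where(rate < large_rate, small_sample, large_sample)

key = random.PRNGKey(0)
samples = jax.vmap(lambda k: gumbel_softmax_poisson(k, 3.0))(random.split(key, 5000))
print(jnp.mean(samples))                                        # roughly 3.0
print(jax.grad(lambda r: gumbel_softmax_poisson(key, r))(3.0))  # gradient w.r.t. the rate

Because both branches are smooth in the rate, gradients of a plan's return can flow through the Poisson sample.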
1538 +
1539 +     def _jax_negative_binomial(self, expr, aux):
1540 +         ERR = JaxRDDLCompilerWithGrad.ERROR_CODES['INVALID_PARAM_NEGATIVE_BINOMIAL']
1541 +         JaxRDDLCompilerWithGrad._check_num_args(expr, 2)
1542 +         arg_trials, arg_prob = expr.args
1543 +
1544 +         # if prob and trials is non-fluent, always use the exact operation
1545 +         if (not self.traced.cached_is_fluent(arg_trials) and
1546 +                 not self.traced.cached_is_fluent(arg_prob)):
1547 +             return super()._jax_negative_binomial(expr, aux)
1548 +
1549 +         id_ = expr.id
1550 +         aux['params'][id_] = (self.poisson_softmax_weight, self.poisson_min_cdf, self.poisson_eps)
1551 +         aux['overriden'][id_] = __class__.__name__
1552 +
1553 +         jax_trials = self._jax(arg_trials, aux)
1554 +         jax_prob = self._jax(arg_prob, aux)
1349 1555
1556 +         # https://en.wikipedia.org/wiki/Negative_binomial_distribution#Gamma%E2%80%93Poisson_mixture
1557 +         def _jax_wrapped_distribution_negative_binomial_gumbel_softmax(fls, nfls, params, key):
1558 +             trials, key, err2, params = jax_trials(fls, nfls, params, key)
1559 +             prob, key, err1, params = jax_prob(fls, nfls, params, key)
1560 +             key, subkey = random.split(key)
1561 +             trials = jnp.asarray(trials, dtype=self.REAL)
1562 +             prob = jnp.asarray(prob, dtype=self.REAL)
1563 +             gamma = random.gamma(key=subkey, a=trials, dtype=self.REAL)
1564 +             rate = ((1.0 - prob) / prob) * gamma
1565 +             sample = self.branched_approx_to_poisson(subkey, rate, *params[id_])
1566 +             out_of_bounds = jnp.logical_not(jnp.all(
1567 +                 jnp.logical_and(jnp.logical_and(prob >= 0, prob <= 1), trials > 0)))
1568 +             err = err1 | err2 | (out_of_bounds * ERR)
1569 +             return sample, key, err, params
1570 +         return _jax_wrapped_distribution_negative_binomial_gumbel_softmax
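The negative binomial path uses the Gamma-Poisson mixture cited in the comment above: a Gamma(trials)-distributed draw scaled by (1 - prob) / prob becomes the rate of the relaxed Poisson sampler. A small, non-relaxed Monte Carlo check of that identity using only jax.random; the function name is hypothetical:

import jax.numpy as jnp
from jax import random

def negative_binomial_via_gamma_poisson(key, trials, prob, shape):
    # NegBin(trials, prob) as a Gamma-Poisson mixture:
    # rate = Gamma(trials) * (1 - prob) / prob, then sample ~ Poisson(rate)
    key_gamma, key_poisson = random.split(key)
    gamma = random.gamma(key_gamma, a=trials, shape=shape)
    rate = ((1.0 - prob) / prob) * gamma
    return random.poisson(key_poisson, rate)

samples = negative_binomial_via_gamma_poisson(
    random.PRNGKey(1), trials=5.0, prob=0.4, shape=(100000,))
print(jnp.mean(samples))    # close to 5 * (1 - 0.4) / 0.4 = 7.5
print(jnp.var(samples))     # close to 7.5 / 0.4 = 18.75

The empirical mean lands near trials * (1 - prob) / prob, which is exactly the quantity the determinized variant later in this diff substitutes for the sample.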
1350 1571
1351 - # ===========================================================================
1352 - # UNIT TESTS
1353 - #
1354 - # ===========================================================================
1355 -
1356 - logic = FuzzyLogic(comparison=SigmoidComparison(10000.0),
1357 -                    rounding=SoftRounding(10000.0),
1358 -                    control=SoftControlFlow(10000.0))
1359 -
1360 -
1361 - def _test_logical():
1362 -     print('testing logical')
1363 -     init_params = {}
1364 -     _and = logic.logical_and(0, init_params)
1365 -     _not = logic.logical_not(1, init_params)
1366 -     _gre = logic.greater(2, init_params)
1367 -     _or = logic.logical_or(3, init_params)
1368 -     _if = logic.control_if(4, init_params)
1369 -     print(init_params)
1370 -
1371 -     # https://towardsdatascience.com/emulating-logical-gates-with-a-neural-network-75c229ec4cc9
1372 -     def test_logic(x1, x2, w):
1373 -         q1, w = _gre(x1, 0, w)
1374 -         q2, w = _gre(x2, 0, w)
1375 -         q3, w = _and(q1, q2, w)
1376 -         q4, w = _not(q1, w)
1377 -         q5, w = _not(q2, w)
1378 -         q6, w = _and(q4, q5, w)
1379 -         cond, w = _or(q3, q6, w)
1380 -         pred, w = _if(cond, +1, -1, w)
1381 -         return pred
1382 -
1383 -     x1 = jnp.asarray([1, 1, -1, -1, 0.1, 15, -0.5], dtype=float)
1384 -     x2 = jnp.asarray([1, -1, 1, -1, 10, -30, 6], dtype=float)
1385 -     print(test_logic(x1, x2, init_params))
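The removed _test_logical above wires the relaxed primitives into a smoothed XNOR-style gate: it outputs +1 when x1 and x2 have the same sign and -1 otherwise. Its exact, non-differentiable counterpart, useful for seeing what the soft version approximates:

import jax.numpy as jnp

def hard_logic(x1, x2):
    # exact counterpart of the removed soft test: +1 when x1 and x2 have the
    # same sign (an XNOR of the two threshold tests q1 = x1 > 0, q2 = x2 > 0)
    q1, q2 = x1 > 0, x2 > 0
    cond = (q1 & q2) | (~q1 & ~q2)
    return jnp.where(cond, 1.0, -1.0)

x1 = jnp.asarray([1, 1, -1, -1, 0.1, 15, -0.5], dtype=float)
x2 = jnp.asarray([1, -1, 1, -1, 10, -30, 6], dtype=float)
print(hard_logic(x1, x2))   # [ 1. -1. -1.  1.  1. -1. -1.]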
1386 -
1387 -
1388 - def _test_indexing():
1389 -     print('testing indexing')
1390 -     init_params = {}
1391 -     _argmax = logic.argmax(0, init_params)
1392 -     _argmin = logic.argmin(1, init_params)
1393 -     print(init_params)
1394 -
1395 -     def argmaxmin(x, w):
1396 -         amax, w = _argmax(x, 0, w)
1397 -         amin, w = _argmin(x, 0, w)
1398 -         return amax, amin
1399 -
1400 -     values = jnp.asarray([2., 3., 5., 4.9, 4., 1., -1., -2.])
1401 -     amax, amin = argmaxmin(values, init_params)
1402 -     print(amax)
1403 -     print(amin)
1404 -
1405 -
1406 - def _test_control():
1407 -     print('testing control')
1408 -     init_params = {}
1409 -     _switch = logic.control_switch(0, init_params)
1410 -     print(init_params)
1411 -
1412 -     pred = jnp.asarray(jnp.linspace(0, 2, 10))
1413 -     case1 = jnp.asarray([-10.] * 10)
1414 -     case2 = jnp.asarray([1.5] * 10)
1415 -     case3 = jnp.asarray([10.] * 10)
1416 -     cases = jnp.asarray([case1, case2, case3])
1417 -     switch, _ = _switch(pred, cases, init_params)
1418 -     print(switch)
1419 -
1420 -
1421 - def _test_random():
1422 -     print('testing random')
1423 -     key = random.PRNGKey(42)
1424 -     init_params = {}
1425 -     _bernoulli = logic.bernoulli(0, init_params)
1426 -     _discrete = logic.discrete(1, init_params)
1427 -     _geometric = logic.geometric(2, init_params)
1428 -     print(init_params)
1429 -
1430 -     def bern(n, w):
1431 -         prob = jnp.asarray([0.3] * n)
1432 -         sample, _ = _bernoulli(key, prob, w)
1433 -         return sample
1434 -
1435 -     samples = bern(50000, init_params)
1436 -     print(jnp.mean(samples))
1437 -
1438 -     def disc(n, w):
1439 -         prob = jnp.asarray([0.1, 0.4, 0.5])
1440 -         prob = jnp.tile(prob, (n, 1))
1441 -         sample, _ = _discrete(key, prob, w)
1442 -         return sample
1443 -
1444 -     samples = disc(50000, init_params)
1445 -     samples = jnp.round(samples)
1446 -     print([jnp.mean(samples == i) for i in range(3)])
1447 -
1448 -     def geom(n, w):
1449 -         prob = jnp.asarray([0.3] * n)
1450 -         sample, _ = _geometric(key, prob, w)
1451 -         return sample
1452 -
1453 -     samples = geom(50000, init_params)
1454 -     print(jnp.mean(samples))
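The removed _test_random checks that the relaxed Bernoulli, Discrete, and Geometric samplers reproduce the expected moments. A rough sketch of one such relaxation, a sigmoid-smoothed Bernoulli in the spirit of the ReparameterizedSigmoidBernoulli mixin named later in this diff; both the functional form and the weight parameter are assumptions, not the package's implementation:

import jax
import jax.numpy as jnp
from jax import random

def sigmoid_bernoulli(key, prob, weight=10.0):
    # relaxed Bernoulli(prob): compare a uniform draw to prob through a sigmoid
    # instead of a hard threshold, so the sample is differentiable in prob
    u = random.uniform(key, shape=jnp.shape(prob))
    return jax.nn.sigmoid(weight * (prob - u))

key = random.PRNGKey(42)
print(jnp.mean(sigmoid_bernoulli(key, jnp.full((50000,), 0.3))))   # close to 0.3
# the relaxed sample carries a gradient with respect to the probability
print(jax.grad(lambda p: jnp.sum(sigmoid_bernoulli(key, p * jnp.ones(4))))(0.3))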
1455 -
1456 1572
1457-1479 -   (23 blank lines removed)
1573 + class DeterminizedPoisson(JaxRDDLCompilerWithGrad):
1574 +
1575 +     def __init__(self, *args, **kwargs) -> None:
1576 +         super(DeterminizedPoisson, self).__init__(*args, **kwargs)
1577 +
1578 +     def get_kwargs(self) -> Dict[str, Any]:
1579 +         return super().get_kwargs()
1580 +
1581 +     def _jax_poisson(self, expr, aux):
1582 +         ERR = JaxRDDLCompilerWithGrad.ERROR_CODES['INVALID_PARAM_POISSON']
1583 +         JaxRDDLCompilerWithGrad._check_num_args(expr, 1)
1584 +         arg_rate, = expr.args
1585 +
1586 +         # if rate is non-fluent, always use the exact operation
1587 +         if not self.traced.cached_is_fluent(arg_rate):
1588 +             return super()._jax_poisson(expr, aux)
1589 +
1590 +         aux['overriden'][expr.id] = __class__.__name__
1591 +
1592 +         jax_rate = self._jax(arg_rate, aux)
1593 +
1594 +         def _jax_wrapped_distribution_poisson_determinized(fls, nfls, params, key):
1595 +             rate, key, err, params = jax_rate(fls, nfls, params, key)
1596 +             sample = rate
1597 +             out_of_bounds = jnp.logical_not(jnp.all(rate >= 0))
1598 +             err = err | (out_of_bounds * ERR)
1599 +             return sample, key, err, params
1600 +         return _jax_wrapped_distribution_poisson_determinized
1601 +
1602 +     def _jax_negative_binomial(self, expr, aux):
1603 +         ERR = JaxRDDLCompilerWithGrad.ERROR_CODES['INVALID_PARAM_NEGATIVE_BINOMIAL']
1604 +         JaxRDDLCompilerWithGrad._check_num_args(expr, 2)
1605 +         arg_trials, arg_prob = expr.args
1606 +
1607 +         # if prob and trials is non-fluent, always use the exact operation
1608 +         if (not self.traced.cached_is_fluent(arg_trials) and
1609 +                 not self.traced.cached_is_fluent(arg_prob)):
1610 +             return super()._jax_negative_binomial(expr, aux)
1611 +
1612 +         aux['overriden'][expr.id] = __class__.__name__
1613 +
1614 +         jax_trials = self._jax(arg_trials, aux)
1615 +         jax_prob = self._jax(arg_prob, aux)
1616 +
1617 +         def _jax_wrapped_distribution_negative_binomial_determinized(fls, nfls, params, key):
1618 +             trials, key, err2, params = jax_trials(fls, nfls, params, key)
1619 +             prob, key, err1, params = jax_prob(fls, nfls, params, key)
1620 +             trials = jnp.asarray(trials, dtype=self.REAL)
1621 +             prob = jnp.asarray(prob, dtype=self.REAL)
1622 +             sample = ((1.0 - prob) / prob) * trials
1623 +             out_of_bounds = jnp.logical_not(jnp.all(
1624 +                 jnp.logical_and(jnp.logical_and(prob >= 0, prob <= 1), trials > 0)))
1625 +             err = err1 | err2 | (out_of_bounds * ERR)
1626 +             return sample, key, err, params
1627 +         return _jax_wrapped_distribution_negative_binomial_determinized
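DeterminizedPoisson replaces both samplers with their means: the Poisson with its rate, and the negative binomial with trials * (1 - prob) / prob. A tiny standalone illustration of this mean-substitution determinization, using hypothetical helper names rather than the compiler methods above:

import jax
import jax.numpy as jnp

def determinized_poisson(rate):
    # mean substitution: E[Poisson(rate)] = rate
    return rate

def determinized_negative_binomial(trials, prob):
    # mean substitution: E[NegBin(trials, prob)] = trials * (1 - prob) / prob
    return ((1.0 - prob) / prob) * trials

print(determinized_poisson(jnp.asarray(3.0)))                          # 3.0
print(determinized_negative_binomial(jnp.asarray(5.0), 0.4))           # 7.5
# the "sample" is now an exact, noise-free function of its parameters
print(jax.grad(determinized_negative_binomial, argnums=1)(5.0, 0.4))   # -5 / 0.4**2 = -31.25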
1628 +
1629 +
1630 + class DefaultJaxRDDLCompilerWithGrad(SigmoidRelational, SoftmaxArgmax,
1631 +                                      ProductNormLogical,
1632 +                                      SafeSqrt, SoftFloor, SoftRound,
1633 +                                      LinearIfElse, SoftmaxSwitch,
1634 +                                      ReparameterizedGeometric,
1635 +                                      ReparameterizedSigmoidBernoulli,
1636 +                                      GumbelSoftmaxDiscrete, GumbelSoftmaxBinomial,
1637 +                                      ExponentialPoisson):
1638 +
1639 +     def __init__(self, *args, **kwargs) -> None:
1640 +         super(DefaultJaxRDDLCompilerWithGrad, self).__init__(*args, **kwargs)
1641 +
1642 +     def get_kwargs(self) -> Dict[str, Any]:
1643 +         kwargs = {}
1644 +         for base in type(self).__bases__:
1645 +             if base.__name__ != 'object':
1646 +                 kwargs = {**kwargs, **base.get_kwargs(self)}
1647 +         return kwargs
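DefaultJaxRDDLCompilerWithGrad assembles the default relaxation by multiple inheritance over the individual relaxation mixins, and its get_kwargs merges the keyword arguments reported by each direct base into a single dict. A toy sketch of the same composition pattern with hypothetical mixins, not the package's actual classes:

from typing import Any, Dict

class Base:
    def get_kwargs(self) -> Dict[str, Any]:
        return {}

class SigmoidMixin(Base):
    def __init__(self, sigmoid_weight: float = 10.0, **kwargs) -> None:
        super().__init__(**kwargs)
        self.sigmoid_weight = sigmoid_weight

    def get_kwargs(self) -> Dict[str, Any]:
        return {'sigmoid_weight': self.sigmoid_weight}

class RoundingMixin(Base):
    def __init__(self, rounding_weight: float = 15.0, **kwargs) -> None:
        super().__init__(**kwargs)
        self.rounding_weight = rounding_weight

    def get_kwargs(self) -> Dict[str, Any]:
        return {'rounding_weight': self.rounding_weight}

class DefaultCompiler(SigmoidMixin, RoundingMixin):
    def get_kwargs(self) -> Dict[str, Any]:
        # merge the kwargs reported by every direct base, as in the diff above
        kwargs = {}
        for base in type(self).__bases__:
            if base.__name__ != 'object':
                kwargs = {**kwargs, **base.get_kwargs(self)}
        return kwargs

print(DefaultCompiler().get_kwargs())
# {'sigmoid_weight': 10.0, 'rounding_weight': 15.0}

Listing the mixins explicitly keeps each relaxation independently swappable, while the merged kwargs dict lets a caller reconstruct an equivalent compiler from one configuration mapping.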