pyRDDLGym-jax 2.8__py3-none-any.whl → 3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. pyRDDLGym_jax/__init__.py +1 -1
  2. pyRDDLGym_jax/core/compiler.py +1080 -906
  3. pyRDDLGym_jax/core/logic.py +1537 -1369
  4. pyRDDLGym_jax/core/model.py +75 -86
  5. pyRDDLGym_jax/core/planner.py +883 -935
  6. pyRDDLGym_jax/core/simulator.py +20 -17
  7. pyRDDLGym_jax/core/tuning.py +11 -7
  8. pyRDDLGym_jax/core/visualization.py +115 -78
  9. pyRDDLGym_jax/entry_point.py +2 -1
  10. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +6 -8
  11. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +5 -7
  12. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +7 -8
  13. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +7 -8
  14. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +8 -9
  15. pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +5 -7
  16. pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +5 -7
  17. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +7 -8
  18. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +6 -7
  19. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +6 -7
  20. pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +6 -8
  21. pyRDDLGym_jax/examples/configs/Quadcopter_physics_drp.cfg +17 -0
  22. pyRDDLGym_jax/examples/configs/Quadcopter_physics_slp.cfg +17 -0
  23. pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +5 -7
  24. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +4 -7
  25. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +5 -7
  26. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +4 -7
  27. pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +5 -7
  28. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +6 -7
  29. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +6 -7
  30. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +6 -7
  31. pyRDDLGym_jax/examples/configs/default_drp.cfg +5 -8
  32. pyRDDLGym_jax/examples/configs/default_replan.cfg +5 -8
  33. pyRDDLGym_jax/examples/configs/default_slp.cfg +5 -8
  34. pyRDDLGym_jax/examples/configs/tuning_drp.cfg +6 -8
  35. pyRDDLGym_jax/examples/configs/tuning_replan.cfg +6 -8
  36. pyRDDLGym_jax/examples/configs/tuning_slp.cfg +6 -8
  37. pyRDDLGym_jax/examples/run_plan.py +2 -2
  38. pyRDDLGym_jax/examples/run_tune.py +2 -2
  39. {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/METADATA +22 -23
  40. pyrddlgym_jax-3.0.dist-info/RECORD +51 -0
  41. {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/WHEEL +1 -1
  42. pyRDDLGym_jax/examples/run_gradient.py +0 -102
  43. pyrddlgym_jax-2.8.dist-info/RECORD +0 -50
  44. {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/entry_points.txt +0 -0
  45. {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/licenses/LICENSE +0 -0
  46. {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/top_level.txt +0 -0
@@ -29,26 +29,16 @@
  #
  # ***********************************************************************

+ import termcolor
+ from typing import Any, Dict, Optional, Set, Tuple, Union

- from abc import ABCMeta, abstractmethod
- import traceback
- from typing import Callable, Dict, Tuple, Union
-
+ import numpy as np
  import jax
  import jax.numpy as jnp
  import jax.random as random
  import jax.scipy as scipy

- from pyRDDLGym.core.debug.exception import raise_warning
-
- # more robust approach - if user does not have this or broken try to continue
- try:
- from tensorflow_probability.substrates import jax as tfp
- except Exception:
- raise_warning('Failed to import tensorflow-probability: '
- 'compilation of some probability distributions will fail.', 'red')
- traceback.print_exc()
- tfp = None
+ from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler


  def enumerate_literals(shape: Tuple[int, ...], axis: int, dtype: type=jnp.int32) -> jnp.ndarray:
@@ -59,1421 +49,1599 @@ def enumerate_literals(shape: Tuple[int, ...], axis: int, dtype: type=jnp.int32)
  return literals


- # ===========================================================================
- # RELATIONAL OPERATIONS
- # - abstract class
- # - sigmoid comparison
- #
- # ===========================================================================
-
- class Comparison(metaclass=ABCMeta):
- '''Base class for approximate comparison operations.'''
-
- @abstractmethod
- def greater_equal(self, id, init_params):
- pass
-
- @abstractmethod
- def greater(self, id, init_params):
- pass
-
- @abstractmethod
- def equal(self, id, init_params):
- pass
-
- @abstractmethod
- def sgn(self, id, init_params):
- pass
-
- @abstractmethod
- def argmax(self, id, init_params):
- pass
-
-
- class SigmoidComparison(Comparison):
+ # branching sigmoid to help reduce numerical issues
+ @jax.custom_jvp
+ def stable_sigmoid(x: jnp.ndarray) -> jnp.ndarray:
+ return jnp.where(x >= 0, 1.0 / (1.0 + jnp.exp(-x)), jnp.exp(x) / (1.0 + jnp.exp(x)))
+
+
+ @stable_sigmoid.defjvp
+ def stable_sigmoid_jvp(primals, tangents):
+ (x,), (x_dot,) = primals, tangents
+ s = stable_sigmoid(x)
+ primal_out = s
+ tangent_out = x_dot * s * (1.0 - s)
+ return primal_out, tangent_out
+
+
+ # branching tanh to help reduce numerical issues
+ @jax.custom_jvp
+ def stable_tanh(x: jnp.ndarray) -> jnp.ndarray:
+ ax = jnp.abs(x)
+ small = jnp.where(
+ ax < 20.0,
+ jnp.expm1(2.0 * ax) / (jnp.expm1(2.0 * ax) + 2.0),
+ 1.0 - 2.0 * jnp.exp(-2.0 * ax)
+ )
+ return jnp.sign(x) * small
+
+
+ @stable_tanh.defjvp
+ def stable_tanh_jvp(primals, tangents):
+ (x,), (x_dot,) = primals, tangents
+ t = stable_tanh(x)
+ tangent_out = x_dot * (1.0 - t * t)
+ return t, tangent_out
+
+
+ # it seems JAX uses the stability trick already
+ def stable_softmax_weight_sum(logits: jnp.ndarray,
+ values: jnp.ndarray,
+ axis: Union[int, Tuple[int, ...]]) -> jnp.ndarray:
+ return jnp.sum(values * jax.nn.softmax(logits), axis=axis)
+
+
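The two helpers above pair a numerically stable forward pass with the textbook derivative via jax.custom_jvp. A minimal usage sketch of how such a custom derivative behaves (illustrative only, not part of the package; assumes the stable_sigmoid defined above is in scope):

    import jax
    import jax.numpy as jnp
    x = jnp.linspace(-30.0, 30.0, 5)
    y = stable_sigmoid(x)                         # stays in (0, 1) even for large |x|
    dy = jax.vmap(jax.grad(stable_sigmoid))(x)    # uses the custom tangent rule s * (1 - s)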
+ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
+ '''Compiles a RDDL AST representation to an equivalent JAX representation.
+ Unlike its parent class, this class treats all fluents as real-valued, and
+ replaces all mathematical operations by equivalent ones with a well defined
+ (e.g. non-zero) gradient where appropriate.
+ '''
+
+ def __init__(self, *args,
+ cpfs_without_grad: Optional[Set[str]]=None,
+ print_warnings: bool=True,
+ **kwargs) -> None:
+ '''Creates a new RDDL to Jax compiler, where operations that are not
+ differentiable are converted to approximate forms that have defined gradients.
+
+ :param *args: arguments to pass to base compiler
+ :param cpfs_without_grad: which CPFs do not have gradients (use straight
+ through gradient trick)
+ :param print_warnings: whether to print warnings
+ :param *kwargs: keyword arguments to pass to base compiler
+ '''
+ super(JaxRDDLCompilerWithGrad, self).__init__(*args, **kwargs)
+
+ if cpfs_without_grad is None:
+ cpfs_without_grad = set()
+ self.cpfs_without_grad = cpfs_without_grad
+ self.print_warnings = print_warnings
+
+ # actions and CPFs must be continuous
+ pvars_cast = set()
+ for (var, values) in self.init_values.items():
+ self.init_values[var] = np.asarray(values, dtype=self.REAL)
+ if not np.issubdtype(np.result_type(values), np.floating):
+ pvars_cast.add(var)
+ if self.print_warnings and pvars_cast:
+ print(termcolor.colored(
+ f'[INFO] Compiler will cast pvars {pvars_cast} to float.', 'dark_grey'))
+
+ def get_kwargs(self) -> Dict[str, Any]:
+ kwargs = super().get_kwargs()
+ kwargs['cpfs_without_grad'] = self.cpfs_without_grad
+ kwargs['print_warnings'] = self.print_warnings
+ return kwargs
+
+ def _jax_stop_grad(self, jax_expr):
+ def _jax_wrapped_stop_grad(fls, nfls, params, key):
+ sample, key, error, params = jax_expr(fls, nfls, params, key)
+ sample = jax.lax.stop_gradient(sample)
+ return sample, key, error, params
+ return _jax_wrapped_stop_grad
+
+ def _compile_cpfs(self, aux):
+
+ # cpfs will all be cast to float
+ cpfs_cast = set()
+ jax_cpfs = {}
+ for (_, cpfs) in self.levels.items():
+ for cpf in cpfs:
+ _, expr = self.rddl.cpfs[cpf]
+ jax_cpfs[cpf] = self._jax(expr, aux, dtype=self.REAL)
+ if self.rddl.variable_ranges[cpf] != 'real':
+ cpfs_cast.add(cpf)
+ if cpf in self.cpfs_without_grad:
+ jax_cpfs[cpf] = self._jax_stop_grad(jax_cpfs[cpf])
+
+ if self.print_warnings and cpfs_cast:
+ print(termcolor.colored(
+ f'[INFO] Compiler will cast CPFs {cpfs_cast} to float.', 'dark_grey'))
+ if self.print_warnings and self.cpfs_without_grad:
+ print(termcolor.colored(
+ f'[INFO] Gradient disabled for CPFs {self.cpfs_without_grad}.', 'dark_grey'))
+
+ return jax_cpfs
+
+ def _jax_unary_with_param(self, jax_expr, jax_op):
+ def _jax_wrapped_unary_op_with_param(fls, nfls, params, key):
+ sample, key, err, params = jax_expr(fls, nfls, params, key)
+ sample = self.ONE * sample
+ sample, params = jax_op(sample, params)
+ return sample, key, err, params
+ return _jax_wrapped_unary_op_with_param
+
+ def _jax_binary_with_param(self, jax_lhs, jax_rhs, jax_op):
+ def _jax_wrapped_binary_op_with_param(fls, nfls, params, key):
+ sample1, key, err1, params = jax_lhs(fls, nfls, params, key)
+ sample2, key, err2, params = jax_rhs(fls, nfls, params, key)
+ sample1 = self.ONE * sample1
+ sample2 = self.ONE * sample2
+ sample, params = jax_op(sample1, sample2, params)
+ err = err1 | err2
+ return sample, key, err, params
+ return _jax_wrapped_binary_op_with_param
+
+ def _jax_unary_helper_with_param(self, expr, aux, jax_op):
+ JaxRDDLCompilerWithGrad._check_num_args(expr, 1)
+ arg, = expr.args
+ jax_arg = self._jax(arg, aux)
+ return self._jax_unary_with_param(jax_arg, jax_op)
+
+ def _jax_binary_helper_with_param(self, expr, aux, jax_op):
+ JaxRDDLCompilerWithGrad._check_num_args(expr, 2)
+ lhs, rhs = expr.args
+ jax_lhs = self._jax(lhs, aux)
+ jax_rhs = self._jax(rhs, aux)
+ return self._jax_binary_with_param(jax_lhs, jax_rhs, jax_op)
+
+ def _jax_kron(self, expr, aux):
+ aux['overriden'][expr.id] = __class__.__name__
+ arg, = expr.args
+ arg = self._jax(arg, aux)
+ return arg
+
+
+ # ===============================================================================
+ # relational relaxations
+ # ===============================================================================
+
+ # https://arxiv.org/abs/2110.05651
+ class SigmoidRelational(JaxRDDLCompilerWithGrad):
  '''Comparison operations approximated using sigmoid functions.'''

- def __init__(self, weight: float=10.0) -> None:
- self.weight = float(weight)
-
- # https://arxiv.org/abs/2110.05651
- def greater_equal(self, id, init_params):
- id_ = str(id)
- init_params[id_] = self.weight
- def _jax_wrapped_calc_greater_equal_approx(x, y, params):
- gre_eq = jax.nn.sigmoid(params[id_] * (x - y))
- return gre_eq, params
- return _jax_wrapped_calc_greater_equal_approx
-
- def greater(self, id, init_params):
- return self.greater_equal(id, init_params)
-
- def equal(self, id, init_params):
- id_ = str(id)
- init_params[id_] = self.weight
- def _jax_wrapped_calc_equal_approx(x, y, params):
- equal = 1.0 - jnp.square(jnp.tanh(params[id_] * (y - x)))
- return equal, params
- return _jax_wrapped_calc_equal_approx
-
- def sgn(self, id, init_params):
- id_ = str(id)
- init_params[id_] = self.weight
- def _jax_wrapped_calc_sgn_approx(x, params):
- sgn = jnp.tanh(params[id_] * x)
- return sgn, params
- return _jax_wrapped_calc_sgn_approx
-
- # https://arxiv.org/abs/2110.05651
- def argmax(self, id, init_params):
- id_ = str(id)
- init_params[id_] = self.weight
- def _jax_wrapped_calc_argmax_approx(x, axis, params):
- literals = enumerate_literals(jnp.shape(x), axis=axis)
- softmax = jax.nn.softmax(params[id_] * x, axis=axis)
- sample = jnp.sum(literals * softmax, axis=axis)
+ def __init__(self, *args, sigmoid_weight: float=10.,
+ use_sigmoid_ste: bool=True, use_tanh_ste: bool=True,
+ **kwargs) -> None:
+ super(SigmoidRelational, self).__init__(*args, **kwargs)
+ self.sigmoid_weight = float(sigmoid_weight)
+ self.use_sigmoid_ste = use_sigmoid_ste
+ self.use_tanh_ste = use_tanh_ste
+
+ def get_kwargs(self) -> Dict[str, Any]:
+ kwargs = super().get_kwargs()
+ kwargs['sigmoid_weight'] = self.sigmoid_weight
+ kwargs['use_sigmoid_ste'] = self.use_sigmoid_ste
+ kwargs['use_tanh_ste'] = self.use_tanh_ste
+ return kwargs
+
+ def _jax_greater(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_greater(expr, aux)
+ id_ = expr.id
+ aux['params'][id_] = self.sigmoid_weight
+ aux['overriden'][id_] = __class__.__name__
+ def greater_op(x, y, params):
+ sample = stable_sigmoid(params[id_] * (x - y))
+ if self.use_sigmoid_ste:
+ sample = sample + jax.lax.stop_gradient(jnp.greater(x, y) - sample)
+ return sample, params
+ return self._jax_binary_helper_with_param(expr, aux, greater_op)
+
+ def _jax_greater_equal(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_greater(expr, aux)
+ id_ = expr.id
+ aux['params'][id_] = self.sigmoid_weight
+ aux['overriden'][id_] = __class__.__name__
+ def greater_equal_op(x, y, params):
+ sample = stable_sigmoid(params[id_] * (x - y))
+ if self.use_sigmoid_ste:
+ sample = sample + jax.lax.stop_gradient(jnp.greater_equal(x, y) - sample)
  return sample, params
- return _jax_wrapped_calc_argmax_approx
+ return self._jax_binary_helper_with_param(expr, aux, greater_equal_op)
+
+ def _jax_less(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_less(expr, aux)
+ id_ = expr.id
+ aux['params'][id_] = self.sigmoid_weight
+ aux['overriden'][id_] = __class__.__name__
+ def less_op(x, y, params):
+ sample = stable_sigmoid(params[id_] * (y - x))
+ if self.use_sigmoid_ste:
+ sample = sample + jax.lax.stop_gradient(jnp.less(x, y) - sample)
+ return sample, params
+ return self._jax_binary_helper_with_param(expr, aux, less_op)
+
+ def _jax_less_equal(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_less(expr, aux)
+ id_ = expr.id
+ aux['params'][id_] = self.sigmoid_weight
+ aux['overriden'][id_] = __class__.__name__
+ def less_equal_op(x, y, params):
+ sample = stable_sigmoid(params[id_] * (y - x))
+ if self.use_sigmoid_ste:
+ sample = sample + jax.lax.stop_gradient(jnp.less_equal(x, y) - sample)
+ return sample, params
+ return self._jax_binary_helper_with_param(expr, aux, less_equal_op)
+
+ def _jax_equal(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_equal(expr, aux)
+ id_ = expr.id
+ aux['params'][id_] = self.sigmoid_weight
+ aux['overriden'][id_] = __class__.__name__
+ def equal_op(x, y, params):
+ sample = 1. - jnp.square(stable_tanh(params[id_] * (y - x)))
+ if self.use_tanh_ste:
+ sample = sample + jax.lax.stop_gradient(jnp.equal(x, y) - sample)
+ return sample, params
+ return self._jax_binary_helper_with_param(expr, aux, equal_op)
+
+ def _jax_not_equal(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_not_equal(expr, aux)
+ id_ = expr.id
+ aux['params'][id_] = self.sigmoid_weight
+ aux['overriden'][id_] = __class__.__name__
+ def not_equal_op(x, y, params):
+ sample = jnp.square(stable_tanh(params[id_] * (y - x)))
+ if self.use_tanh_ste:
+ sample = sample + jax.lax.stop_gradient(jnp.not_equal(x, y) - sample)
+ return sample, params
+ return self._jax_binary_helper_with_param(expr, aux, not_equal_op)
+
+ def _jax_sgn(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_sgn(expr, aux)
+ id_ = expr.id
+ aux['params'][id_] = self.sigmoid_weight
+ aux['overriden'][id_] = __class__.__name__
+ def sgn_op(x, params):
+ sample = stable_tanh(params[id_] * x)
+ if self.use_tanh_ste:
+ sample = sample + jax.lax.stop_gradient(jnp.sign(x) - sample)
+ return sample, params
+ return self._jax_unary_helper_with_param(expr, aux, sgn_op)

- def __str__(self) -> str:
- return f'Sigmoid comparison with weight {self.weight}'

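The comparison relaxations above all follow the same straight-through pattern: the sigmoid (or tanh) surrogate supplies the gradient while the exact comparison supplies the forward value. A standalone sketch of that pattern (illustrative only, assuming a fixed weight of 10; not code from the package):

    import jax
    import jax.numpy as jnp

    def ste_greater(x, y, weight=10.0):
        soft = jax.nn.sigmoid(weight * (x - y))        # differentiable surrogate
        hard = jnp.greater(x, y).astype(soft.dtype)    # exact 0/1 comparison
        # forward pass returns hard, gradients flow through soft
        return soft + jax.lax.stop_gradient(hard - soft)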
+ class SoftmaxArgmax(JaxRDDLCompilerWithGrad):
+ '''Argmin/argmax operations approximated using softmax functions.'''

- # ===========================================================================
- # ROUNDING OPERATIONS
- # - abstract class
- # - soft rounding
- #
- # ===========================================================================
-
- class Rounding(metaclass=ABCMeta):
- '''Base class for approximate rounding operations.'''
-
- @abstractmethod
- def floor(self, id, init_params):
- pass
+ def __init__(self, *args, argmax_weight: float=10., **kwargs) -> None:
+ super(SoftmaxArgmax, self).__init__(*args, **kwargs)
+ self.argmax_weight = float(argmax_weight)

- @abstractmethod
- def round(self, id, init_params):
- pass
+ def get_kwargs(self) -> Dict[str, Any]:
+ kwargs = super().get_kwargs()
+ kwargs['argmax_weight'] = self.argmax_weight
+ return kwargs
+
+ @staticmethod
+ def soft_argmax(x: jnp.ndarray, w: float, axes: Union[int, Tuple[int, ...]]) -> jnp.ndarray:
+ literals = enumerate_literals(jnp.shape(x), axis=axes)
+ return stable_softmax_weight_sum(w * x, literals, axis=axes)
+
+ def _jax_argmax(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_argmax(expr, aux)
+ id_ = expr.id
+ aux['params'][id_] = self.argmax_weight
+ aux['overriden'][id_] = __class__.__name__
+ arg = expr.args[-1]
+ _, axes = self.traced.cached_sim_info(expr)
+ jax_expr = self._jax(arg, aux)
+ def argmax_op(x, params):
+ sample = self.soft_argmax(x, params[id_], axes)
+ return sample, params
+ return self._jax_unary_with_param(jax_expr, argmax_op)
+
+ def _jax_argmin(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_argmin(expr, aux)
+ id_ = expr.id
+ aux['params'][id_] = self.argmax_weight
+ aux['overriden'][id_] = __class__.__name__
+ arg = expr.args[-1]
+ _, axes = self.traced.cached_sim_info(expr)
+ jax_expr = self._jax(arg, aux)
+ def argmin_op(x, params):
+ sample = self.soft_argmax(-x, params[id_], axes)
+ return sample, params
+ return self._jax_unary_with_param(jax_expr, argmin_op)
+

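SoftmaxArgmax replaces the index returned by argmax with a softmax-weighted sum of the candidate indices, which approaches the hard argmax as the weight grows. A small numeric illustration (not from the package):

    import jax
    import jax.numpy as jnp
    x = jnp.array([0.1, 2.0, 0.5])
    soft_idx = jnp.sum(jnp.arange(3) * jax.nn.softmax(10.0 * x))   # approximately 1.0, the index of the maximum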
+ # ===============================================================================
+ # logical relaxations
+ # ===============================================================================

- class SoftRounding(Rounding):
- '''Rounding operations approximated using soft operations.'''
+ class ProductNormLogical(JaxRDDLCompilerWithGrad):
+ '''Product t-norm given by the expression (x, y) -> x * y.'''

- def __init__(self, weight: float=10.0) -> None:
- self.weight = float(weight)
-
- # https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/bijectors/Softfloor
- def floor(self, id, init_params):
- id_ = str(id)
- init_params[id_] = self.weight
- def _jax_wrapped_calc_floor_approx(x, params):
- param = params[id_]
- denom = jnp.tanh(param / 4.0)
- floor = (jax.nn.sigmoid(param * (x - jnp.floor(x) - 1.0)) -
- jax.nn.sigmoid(-param / 2.0)) / denom + jnp.floor(x)
- return floor, params
- return _jax_wrapped_calc_floor_approx
+ def __init__(self, *args, use_logic_ste: bool=False, **kwargs) -> None:
+ super(ProductNormLogical, self).__init__(*args, **kwargs)
+ self.use_logic_ste = use_logic_ste
+
+ def get_kwargs(self) -> Dict[str, Any]:
+ kwargs = super().get_kwargs()
+ kwargs['use_logic_ste'] = self.use_logic_ste
+ return kwargs
+
+ def _jax_not(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_not(expr, aux)
+ aux['overriden'][expr.id] = __class__.__name__
+ def not_op(x):
+ sample = 1. - x
+ if self.use_logic_ste:
+ hard_sample = jnp.asarray(x <= 0.5, dtype=self.REAL)
+ sample = sample + jax.lax.stop_gradient(hard_sample - sample)
+ return sample
+ return self._jax_unary_helper(expr, aux, not_op)
+
+ def _jax_and(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_and(expr, aux)
+ aux['overriden'][expr.id] = __class__.__name__
+ def and_op(x, y):
+ sample = jnp.multiply(x, y)
+ if self.use_logic_ste:
+ hard_sample = jnp.asarray(jnp.logical_and(x > 0.5, y > 0.5), dtype=self.REAL)
+ sample = sample + jax.lax.stop_gradient(hard_sample - sample)
+ return sample
+ return self._jax_nary_helper(expr, aux, and_op)
+
+ def _jax_or(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_or(expr, aux)
+ aux['overriden'][expr.id] = __class__.__name__
+ def or_op(x, y):
+ sample = 1. - (1. - x) * (1. - y)
+ if self.use_logic_ste:
+ hard_sample = jnp.asarray(jnp.logical_or(x > 0.5, y > 0.5), dtype=self.REAL)
+ sample = sample + jax.lax.stop_gradient(hard_sample - sample)
+ return sample
+ return self._jax_nary_helper(expr, aux, or_op)
+
+ def _jax_xor(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_xor(expr, aux)
+ aux['overriden'][expr.id] = __class__.__name__
+ def xor_op(x, y):
+ sample = (1. - (1. - x) * (1. - y)) * (1. - x * y)
+ if self.use_logic_ste:
+ hard_sample = jnp.asarray(jnp.logical_xor(x > 0.5, y > 0.5), dtype=self.REAL)
+ sample = sample + jax.lax.stop_gradient(hard_sample - sample)
+ return sample
+ return self._jax_binary_helper(expr, aux, xor_op)
+
+ def _jax_implies(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_implies(expr, aux)
+ aux['overriden'][expr.id] = __class__.__name__
+ def implies_op(x, y):
+ sample = 1. - x * (1. - y)
+ if self.use_logic_ste:
+ hard_sample = jnp.asarray(jnp.logical_or(x <= 0.5, y > 0.5), dtype=self.REAL)
+ sample = sample + jax.lax.stop_gradient(hard_sample - sample)
+ return sample
+ return self._jax_binary_helper(expr, aux, implies_op)
+
+ def _jax_equiv(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_equiv(expr, aux)
+ aux['overriden'][expr.id] = __class__.__name__
+ def equiv_op(x, y):
+ sample = (1. - x * (1. - y)) * (1. - y * (1. - x))
+ if self.use_logic_ste:
+ hard_sample = jnp.logical_and(
+ jnp.logical_or(x <= 0.5, y > 0.5), jnp.logical_or(y <= 0.5, x > 0.5))
+ hard_sample = jnp.asarray(hard_sample, dtype=self.REAL)
+ sample = sample + jax.lax.stop_gradient(hard_sample - sample)
+ return sample
+ return self._jax_binary_helper(expr, aux, equiv_op)
+
+ def _jax_forall(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_forall(expr, aux)
+ aux['overriden'][expr.id] = __class__.__name__
+ def forall_op(x, axis):
+ sample = jnp.prod(x, axis=axis)
+ if self.use_logic_ste:
+ hard_sample = jnp.all(x, axis=axis)
+ sample = sample + jax.lax.stop_gradient(hard_sample - sample)
+ return sample
+ return self._jax_aggregation_helper(expr, aux, forall_op)
+
+ def _jax_exists(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_exists(expr, aux)
+ aux['overriden'][expr.id] = __class__.__name__
+ def exists_op(x, axis):
+ sample = 1. - jnp.prod(1. - x, axis=axis)
+ if self.use_logic_ste:
+ hard_sample = jnp.any(x, axis=axis)
+ sample = sample + jax.lax.stop_gradient(hard_sample - sample)
+ return sample
+ return self._jax_aggregation_helper(expr, aux, exists_op)
+
+
+ class GodelNormLogical(JaxRDDLCompilerWithGrad):
+ '''Godel t-norm given by the expression (x, y) -> min(x, y).'''

- # https://arxiv.org/abs/2006.09952
- def round(self, id, init_params):
- id_ = str(id)
- init_params[id_] = self.weight
- def _jax_wrapped_calc_round_approx(x, params):
+ def __init__(self, *args, use_logic_ste: bool=False, **kwargs) -> None:
+ super(GodelNormLogical, self).__init__(*args, **kwargs)
+ self.use_logic_ste = use_logic_ste
+
+ def get_kwargs(self) -> Dict[str, Any]:
+ kwargs = super().get_kwargs()
+ kwargs['use_logic_ste'] = self.use_logic_ste
+ return kwargs
+
+ def _jax_not(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_not(expr, aux)
+ aux['overriden'][expr.id] = __class__.__name__
+ def not_op(x):
+ sample = 1. - x
+ if self.use_logic_ste:
+ hard_sample = jnp.asarray(x <= 0.5, dtype=self.REAL)
+ sample = sample + jax.lax.stop_gradient(hard_sample - sample)
+ return sample
+ return self._jax_unary_helper(expr, aux, not_op)
+
+ def _jax_and(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_and(expr, aux)
+ aux['overriden'][expr.id] = __class__.__name__
+ def and_op(x, y):
+ sample = jnp.minimum(x, y)
+ if self.use_logic_ste:
+ hard_sample = jnp.asarray(jnp.logical_and(x > 0.5, y > 0.5), dtype=self.REAL)
+ sample = sample + jax.lax.stop_gradient(hard_sample - sample)
+ return sample
+ return self._jax_nary_helper(expr, aux, and_op)
+
+ def _jax_or(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_or(expr, aux)
+ aux['overriden'][expr.id] = __class__.__name__
+ def or_op(x, y):
+ sample = jnp.maximum(x, y)
+ if self.use_logic_ste:
+ hard_sample = jnp.asarray(jnp.logical_or(x > 0.5, y > 0.5), dtype=self.REAL)
+ sample = sample + jax.lax.stop_gradient(hard_sample - sample)
+ return sample
+ return self._jax_nary_helper(expr, aux, or_op)
+
+ def _jax_xor(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_xor(expr, aux)
+ aux['overriden'][expr.id] = __class__.__name__
+ def xor_op(x, y):
+ sample = jnp.minimum(jnp.maximum(x, y), 1. - jnp.minimum(x, y))
+ if self.use_logic_ste:
+ hard_sample = jnp.asarray(jnp.logical_xor(x > 0.5, y > 0.5), dtype=self.REAL)
+ sample = sample + jax.lax.stop_gradient(hard_sample - sample)
+ return sample
+ return self._jax_binary_helper(expr, aux, xor_op)
+
+ def _jax_implies(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_implies(expr, aux)
+ aux['overriden'][expr.id] = __class__.__name__
+ def implies_op(x, y):
+ sample = jnp.maximum(1. - x, y)
+ if self.use_logic_ste:
+ hard_sample = jnp.asarray(jnp.logical_or(x <= 0.5, y > 0.5), dtype=self.REAL)
+ sample = sample + jax.lax.stop_gradient(hard_sample - sample)
+ return sample
+ return self._jax_binary_helper(expr, aux, implies_op)
+
+ def _jax_equiv(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_equiv(expr, aux)
+ aux['overriden'][expr.id] = __class__.__name__
+ def equiv_op(x, y):
+ sample = jnp.minimum(jnp.maximum(1. - x, y), jnp.maximum(1. - y, x))
+ if self.use_logic_ste:
+ hard_sample = jnp.logical_and(
+ jnp.logical_or(x <= 0.5, y > 0.5), jnp.logical_or(y <= 0.5, x > 0.5))
+ hard_sample = jnp.asarray(hard_sample, dtype=self.REAL)
+ sample = sample + jax.lax.stop_gradient(hard_sample - sample)
+ return sample
+ return self._jax_binary_helper(expr, aux, equiv_op)
+
+ def _jax_forall(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_forall(expr, aux)
+ aux['overriden'][expr.id] = __class__.__name__
+ def all_op(x, axis):
+ sample = jnp.min(x, axis=axis)
+ if self.use_logic_ste:
+ hard_sample = jnp.all(x, axis=axis)
+ sample = sample + jax.lax.stop_gradient(hard_sample - sample)
+ return sample
+ return self._jax_aggregation_helper(expr, aux, all_op)
+
+ def _jax_exists(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_exists(expr, aux)
+ aux['overriden'][expr.id] = __class__.__name__
+ def exists_op(x, axis):
+ sample = jnp.max(x, axis=axis)
+ if self.use_logic_ste:
+ hard_sample = jnp.any(x, axis=axis)
+ sample = sample + jax.lax.stop_gradient(hard_sample - sample)
+ return sample
+ return self._jax_aggregation_helper(expr, aux, exists_op)
+
+
+ class LukasiewiczNormLogical(JaxRDDLCompilerWithGrad):
+ '''Lukasiewicz t-norm given by the expression (x, y) -> max(x + y - 1, 0).'''
+
+ def __init__(self, *args, use_logic_ste: bool=False, **kwargs) -> None:
+ super(LukasiewiczNormLogical, self).__init__(*args, **kwargs)
+ self.use_logic_ste = use_logic_ste
+
+ def get_kwargs(self) -> Dict[str, Any]:
+ kwargs = super().get_kwargs()
+ kwargs['use_logic_ste'] = self.use_logic_ste
+ return kwargs
+
+ def _jax_not(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_not(expr, aux)
+ aux['overriden'][expr.id] = __class__.__name__
+ def not_op(x):
+ sample = 1. - x
+ if self.use_logic_ste:
+ hard_sample = jnp.asarray(x <= 0.5, dtype=self.REAL)
+ sample = sample + jax.lax.stop_gradient(hard_sample - sample)
+ return sample
+ return self._jax_unary_helper(expr, aux, not_op)
+
+ def _jax_and(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_and(expr, aux)
+ aux['overriden'][expr.id] = __class__.__name__
+ def and_op(x, y):
+ sample = jax.nn.relu(x + y - 1.)
+ if self.use_logic_ste:
+ hard_sample = jnp.asarray(jnp.logical_and(x > 0.5, y > 0.5), dtype=self.REAL)
+ sample = sample + jax.lax.stop_gradient(hard_sample - sample)
+ return sample
+ return self._jax_nary_helper(expr, aux, and_op)
+
+ def _jax_or(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_or(expr, aux)
+ aux['overriden'][expr.id] = __class__.__name__
+ def or_op(x, y):
+ sample = 1. - jax.nn.relu(1. - x - y)
+ if self.use_logic_ste:
+ hard_sample = jnp.asarray(jnp.logical_or(x > 0.5, y > 0.5), dtype=self.REAL)
+ sample = sample + jax.lax.stop_gradient(hard_sample - sample)
+ return sample
+ return self._jax_nary_helper(expr, aux, or_op)
+
+ def _jax_xor(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_xor(expr, aux)
+ aux['overriden'][expr.id] = __class__.__name__
+ def xor_op(x, y):
+ sample = jax.nn.relu(1. - jnp.abs(1. - x - y))
+ if self.use_logic_ste:
+ hard_sample = jnp.asarray(jnp.logical_xor(x > 0.5, y > 0.5), dtype=self.REAL)
+ sample = sample + jax.lax.stop_gradient(hard_sample - sample)
+ return sample
+ return self._jax_binary_helper(expr, aux, xor_op)
+
+ def _jax_implies(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_implies(expr, aux)
+ aux['overriden'][expr.id] = __class__.__name__
+ def implies_op(x, y):
+ sample = 1. - jax.nn.relu(x - y)
+ if self.use_logic_ste:
+ hard_sample = jnp.asarray(jnp.logical_or(x <= 0.5, y > 0.5), dtype=self.REAL)
+ sample = sample + jax.lax.stop_gradient(hard_sample - sample)
+ return sample
+ return self._jax_binary_helper(expr, aux, implies_op)
+
+ def _jax_equiv(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_equiv(expr, aux)
+ aux['overriden'][expr.id] = __class__.__name__
+ def equiv_op(x, y):
+ sample = jax.nn.relu(1. - jnp.abs(x - y))
+ if self.use_logic_ste:
+ hard_sample = jnp.logical_and(
+ jnp.logical_or(x <= 0.5, y > 0.5), jnp.logical_or(y <= 0.5, x > 0.5))
+ hard_sample = jnp.asarray(hard_sample, dtype=self.REAL)
+ sample = sample + jax.lax.stop_gradient(hard_sample - sample)
+ return sample
+ return self._jax_binary_helper(expr, aux, equiv_op)
+
+ def _jax_forall(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_forall(expr, aux)
+ aux['overriden'][expr.id] = __class__.__name__
+ def forall_op(x, axis):
+ sample = jax.nn.relu(jnp.sum(x - 1., axis=axis) + 1.)
+ if self.use_logic_ste:
+ hard_sample = jnp.all(x, axis=axis)
+ sample = sample + jax.lax.stop_gradient(hard_sample - sample)
+ return sample
+ return self._jax_aggregation_helper(expr, aux, forall_op)
+
+ def _jax_exists(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_exists(expr, aux)
+ aux['overriden'][expr.id] = __class__.__name__
+ def exists_op(x, axis):
+ sample = 1. - jax.nn.relu(jnp.sum(-x, axis=axis) + 1.)
+ if self.use_logic_ste:
+ hard_sample = jnp.any(x, axis=axis)
+ sample = sample + jax.lax.stop_gradient(hard_sample - sample)
+ return sample
+ return self._jax_aggregation_helper(expr, aux, exists_op)
+
+
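The three logical relaxations above differ only in the t-norm used for conjunction; all coincide with ordinary Boolean logic on crisp 0/1 inputs but give different soft values in between. For example, with fuzzy truth values x = 0.8 and y = 0.6 (illustrative, not from the package):

    x, y = 0.8, 0.6
    product_and = x * y                        # 0.48 (ProductNormLogical)
    godel_and = min(x, y)                      # 0.60 (GodelNormLogical)
    lukasiewicz_and = max(x + y - 1.0, 0.0)    # 0.40 (LukasiewiczNormLogical)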
+ # ===============================================================================
+ # function relaxations
+ # ===============================================================================
+
+ class SafeSqrt(JaxRDDLCompilerWithGrad):
+ '''Sqrt operation without negative underflow.'''
+
+ def __init__(self, *args, sqrt_eps: float=1e-14, **kwargs) -> None:
+ super(SafeSqrt, self).__init__(*args, **kwargs)
+ self.sqrt_eps = float(sqrt_eps)
+
+ def get_kwargs(self) -> Dict[str, Any]:
+ kwargs = super().get_kwargs()
+ kwargs['sqrt_eps'] = self.sqrt_eps
+ return kwargs
+
+ def _jax_sqrt(self, expr, aux):
+ aux['overriden'][expr.id] = __class__.__name__
+ def safe_sqrt_op(x):
+ return jnp.sqrt(x + self.sqrt_eps)
+ return self._jax_unary_helper(expr, aux, safe_sqrt_op, at_least_int=True)
+
+
+ class SoftFloor(JaxRDDLCompilerWithGrad):
+ '''Floor and ceil operations approximated using soft operations.'''
+
+ def __init__(self, *args, floor_weight: float=10.,
+ use_floor_ste: bool=True, **kwargs) -> None:
+ super(SoftFloor, self).__init__(*args, **kwargs)
+ self.floor_weight = float(floor_weight)
+ self.use_floor_ste = use_floor_ste
+
+ def get_kwargs(self) -> Dict[str, Any]:
+ kwargs = super().get_kwargs()
+ kwargs['floor_weight'] = self.floor_weight
+ kwargs['use_floor_ste'] = self.use_floor_ste
+ return kwargs
+
+ @staticmethod
+ def soft_floor(x: jnp.ndarray, w: float) -> jnp.ndarray:
+ s = x - jnp.floor(x)
+ return jnp.floor(x) + 0.5 * (
+ 1. + stable_tanh(w * (s - 1.) / 2.) / stable_tanh(w / 4.))
+
+ def _jax_floor(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_floor(expr, aux)
+ id_ = expr.id
+ aux['params'][id_] = self.floor_weight
+ aux['overriden'][id_] = __class__.__name__
+ def floor_op(x, params):
+ sample = self.soft_floor(x, params[id_])
+ if self.use_floor_ste:
+ sample = sample + jax.lax.stop_gradient(jnp.floor(x) - sample)
+ return sample, params
+ return self._jax_unary_helper_with_param(expr, aux, floor_op)
+
+ def _jax_ceil(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_ceil(expr, aux)
+ id_ = expr.id
+ aux['params'][id_] = self.floor_weight
+ aux['overriden'][id_] = __class__.__name__
+ def ceil_op(x, params):
+ sample = -self.soft_floor(-x, params[id_])
+ if self.use_floor_ste:
+ sample = sample + jax.lax.stop_gradient(jnp.ceil(x) - sample)
+ return sample, params
+ return self._jax_unary_helper_with_param(expr, aux, ceil_op)
+
+ def _jax_div(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_div(expr, aux)
+ id_ = expr.id
+ aux['params'][id_] = self.floor_weight
+ aux['overriden'][id_] = __class__.__name__
+ def div_op(x, y, params):
+ sample = self.soft_floor(x / y, params[id_])
+ if self.use_floor_ste:
+ sample = sample + jax.lax.stop_gradient(jnp.floor_divide(x, y) - sample)
+ return sample, params
+ return self._jax_binary_helper_with_param(expr, aux, div_op)
+
+ def _jax_mod(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_mod(expr, aux)
+ id_ = expr.id
+ aux['params'][id_] = self.floor_weight
+ aux['overriden'][id_] = __class__.__name__
+ def mod_op(x, y, params):
+ div = self.soft_floor(x / y, params[id_])
+ if self.use_floor_ste:
+ div = div + jax.lax.stop_gradient(jnp.floor_divide(x, y) - div)
+ sample = x - y * div
+ return sample, params
+ return self._jax_binary_helper_with_param(expr, aux, mod_op)
+
+
+ class SoftRound(JaxRDDLCompilerWithGrad):
+ '''Round operations approximated using soft operations.'''
+
+ def __init__(self, *args, round_weight: float=10.,
+ use_round_ste: bool=True, **kwargs) -> None:
+ super(SoftRound, self).__init__(*args, **kwargs)
+ self.round_weight = float(round_weight)
+ self.use_round_ste = use_round_ste
+
+ def get_kwargs(self) -> Dict[str, Any]:
+ kwargs = super().get_kwargs()
+ kwargs['round_weight'] = self.round_weight
+ kwargs['use_round_ste'] = self.use_round_ste
+ return kwargs
+
+ def _jax_round(self, expr, aux):
+ if not self.traced.cached_is_fluent(expr):
+ return super()._jax_round(expr, aux)
+ id_ = expr.id
+ aux['params'][id_] = self.round_weight
+ aux['overriden'][id_] = __class__.__name__
+ def round_op(x, params):
  param = params[id_]
  m = jnp.floor(x) + 0.5
- rounded = m + 0.5 * jnp.tanh(param * (x - m)) / jnp.tanh(param / 2.0)
- return rounded, params
- return _jax_wrapped_calc_round_approx
+ sample = m + 0.5 * stable_tanh(param * (x - m)) / stable_tanh(param / 2.)
+ if self.use_round_ste:
+ sample = sample + jax.lax.stop_gradient(jnp.round(x) - sample)
+ return sample, params
+ return self._jax_unary_helper_with_param(expr, aux, round_op)

- def __str__(self) -> str:
- return f'SoftFloor and SoftRound with weight {self.weight}'

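SoftFloor and SoftRound replace the piecewise-constant floor and round with smooth tanh-based surrogates that sharpen as the weight grows, optionally combined with the same straight-through trick. The rounding surrogate used above can be written standalone as (illustrative, not from the package):

    import jax.numpy as jnp

    def soft_round(x, w=10.0):
        m = jnp.floor(x) + 0.5
        return m + 0.5 * jnp.tanh(w * (x - m)) / jnp.tanh(w / 2.0)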
+ # ===============================================================================
+ # control flow relaxations
+ # ===============================================================================

- # ===========================================================================
- # LOGICAL COMPLEMENT
- # - abstract class
- # - standard complement
- #
- # ===========================================================================
+ class LinearIfElse(JaxRDDLCompilerWithGrad):
+ '''Approximate if else statement as a linear combination.'''

- class Complement(metaclass=ABCMeta):
- '''Base class for approximate logical complement operations.'''
-
- @abstractmethod
- def __call__(self, id, init_params):
- pass
+ def __init__(self, *args, use_if_else_ste: bool=True, **kwargs) -> None:
+ super(LinearIfElse, self).__init__(*args, **kwargs)
+ self.use_if_else_ste = use_if_else_ste

+ def get_kwargs(self) -> Dict[str, Any]:
+ kwargs = super().get_kwargs()
+ kwargs['use_if_else_ste'] = self.use_if_else_ste
+ return kwargs

- class StandardComplement(Complement):
- '''The standard approximate logical complement given by x -> 1 - x.'''
-
- # https://www.sciencedirect.com/science/article/abs/pii/016501149190171L
- @staticmethod
- def _jax_wrapped_calc_not_approx(x, params):
- return 1.0 - x, params
-
- def __call__(self, id, init_params):
- return self._jax_wrapped_calc_not_approx
-
- def __str__(self) -> str:
- return 'Standard complement'
-
+ def _jax_if(self, expr, aux):
+ JaxRDDLCompilerWithGrad._check_num_args(expr, 3)
+ pred, if_true, if_false = expr.args

- # ===========================================================================
- # TNORMS
- # - abstract tnorm
- # - product tnorm
- # - Godel tnorm
- # - Lukasiewicz tnorm
- # - Yager(p) tnorm
- #
- # https://www.sciencedirect.com/science/article/abs/pii/016501149190171L
- # ===========================================================================
+ # if predicate is non-fluent, always use the exact operation
+ if not self.traced.cached_is_fluent(pred):
+ return super()._jax_if(expr, aux)
+
+ # recursively compile arguments
+ aux['overriden'][expr.id] = __class__.__name__
+ jax_pred = self._jax(pred, aux)
+ jax_true = self._jax(if_true, aux)
+ jax_false = self._jax(if_false, aux)
+
+ def _jax_wrapped_if_then_else_linear(fls, nfls, params, key):
+ sample_pred, key, err1, params = jax_pred(fls, nfls, params, key)
+ sample_true, key, err2, params = jax_true(fls, nfls, params, key)
+ sample_false, key, err3, params = jax_false(fls, nfls, params, key)
+ if self.use_if_else_ste:
+ hard_pred = (sample_pred > 0.5).astype(sample_pred.dtype)
+ sample_pred = sample_pred + jax.lax.stop_gradient(hard_pred - sample_pred)
+ sample = sample_pred * sample_true + (1 - sample_pred) * sample_false
+ err = err1 | err2 | err3
+ return sample, key, err, params
+ return _jax_wrapped_if_then_else_linear
+
+
+ class SoftmaxSwitch(JaxRDDLCompilerWithGrad):
+ '''Softmax switch control flow using a probabilistic interpretation.'''
+
+ def __init__(self, *args, switch_weight: float=10., **kwargs) -> None:
+ super(SoftmaxSwitch, self).__init__(*args, **kwargs)
+ self.switch_weight = float(switch_weight)
+
+ def get_kwargs(self) -> Dict[str, Any]:
+ kwargs = super().get_kwargs()
+ kwargs['switch_weight'] = self.switch_weight
+ return kwargs
+
+ def _jax_switch(self, expr, aux):
+
+ # if predicate is non-fluent, always use the exact operation
+ # case conditions are currently only literals so they are non-fluent
+ pred = expr.args[0]
+ if not self.traced.cached_is_fluent(pred):
+ return super()._jax_switch(expr, aux)
+
+ id_ = expr.id
+ aux['params'][id_] = self.switch_weight
+ aux['overriden'][id_] = __class__.__name__
+
+ # recursively compile predicate
+ jax_pred = self._jax(pred, aux)
+
+ # recursively compile cases
+ cases, default = self.traced.cached_sim_info(expr)
+ jax_default = None if default is None else self._jax(default, aux)
+ jax_cases = [
+ (jax_default if _case is None else self._jax(_case, aux))
+ for _case in cases
+ ]
+
+ def _jax_wrapped_switch_softmax(fls, nfls, params, key):
+
+ # sample predicate
+ sample_pred, key, err, params = jax_pred(fls, nfls, params, key)
+
+ # sample cases
+ sample_cases = []
+ for jax_case in jax_cases:
+ sample, key, err_case, params = jax_case(fls, nfls, params, key)
+ sample_cases.append(sample)
+ err = err | err_case
+ sample_cases = jnp.asarray(sample_cases)
+ sample_cases = jnp.asarray(sample_cases, dtype=self._fix_dtype(sample_cases))
+
+ # replace integer indexing with softmax
+ sample_pred = jnp.broadcast_to(
+ sample_pred[jnp.newaxis, ...], shape=jnp.shape(sample_cases))
+ literals = enumerate_literals(jnp.shape(sample_cases), axis=0)
+ proximity = -jnp.square(sample_pred - literals)
+ logits = params[id_] * proximity
+ sample = stable_softmax_weight_sum(logits, sample_cases, axis=0)
+ return sample, key, err, params
+ return _jax_wrapped_switch_softmax
+
+
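LinearIfElse treats the (relaxed) predicate as an interpolation weight between the two branches, and SoftmaxSwitch generalizes this to many cases by softmax-weighting them by proximity to the predicate value. The if-else relaxation in isolation (illustrative, not from the package):

    def soft_if(pred, a, b):
        # pred in [0, 1]: 1 selects a, 0 selects b, intermediate values blend the branches
        return pred * a + (1.0 - pred) * b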
+ # ===============================================================================
+ # distribution relaxations - Geometric
+ # ===============================================================================
+
+ class ReparameterizedGeometric(JaxRDDLCompilerWithGrad):
+
+ def __init__(self, *args,
+ geometric_floor_weight: float=10.,
+ geometric_eps: float=1e-14, **kwargs) -> None:
+ super(ReparameterizedGeometric, self).__init__(*args, **kwargs)
+ self.geometric_floor_weight = float(geometric_floor_weight)
+ self.geometric_eps = float(geometric_eps)
+
+ def get_kwargs(self) -> Dict[str, Any]:
+ kwargs = super().get_kwargs()
+ kwargs['geometric_floor_weight'] = self.geometric_floor_weight
+ kwargs['geometric_eps'] = self.geometric_eps
+ return kwargs
+
+ def _jax_geometric(self, expr, aux):
+ ERR = JaxRDDLCompilerWithGrad.ERROR_CODES['INVALID_PARAM_GEOMETRIC']
+ JaxRDDLCompilerWithGrad._check_num_args(expr, 1)
+ arg_prob, = expr.args
+
+ # if prob is non-fluent, always use the exact operation
+ if not self.traced.cached_is_fluent(arg_prob):
+ return super()._jax_geometric(expr, aux)
+
+ id_ = expr.id
+ aux['params'][id_] = (self.geometric_floor_weight, self.geometric_eps)
+ aux['overriden'][id_] = __class__.__name__
+
+ jax_prob = self._jax(arg_prob, aux)
+
+ def _jax_wrapped_distribution_geometric_reparam(fls, nfls, params, key):
+ w, eps = params[id_]
+ prob, key, err, params = jax_prob(fls, nfls, params, key)
+ key, subkey = random.split(key)
+ U = random.uniform(key=subkey, shape=jnp.shape(prob), dtype=self.REAL)
+ sample = 1. + SoftFloor.soft_floor(jnp.log1p(-U) / jnp.log1p(-prob + eps), w=w)
+ out_of_bounds = jnp.logical_not(jnp.all(jnp.logical_and(prob >= 0, prob <= 1)))
+ err = err | (out_of_bounds * ERR)
+ return sample, key, err, params
+ return _jax_wrapped_distribution_geometric_reparam

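ReparameterizedGeometric relies on the inverse-CDF identity N = 1 + floor(log(1 - U) / log(1 - p)) for U ~ Uniform(0, 1), with the floor replaced by SoftFloor so the sample is differentiable in p. The exact, non-relaxed sampler looks like this (illustrative, not from the package):

    import jax
    import jax.numpy as jnp

    def geometric_inverse_cdf(key, p):
        u = jax.random.uniform(key, jnp.shape(p))
        return 1.0 + jnp.floor(jnp.log1p(-u) / jnp.log1p(-p))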
- class TNorm(metaclass=ABCMeta):
- '''Base class for fuzzy differentiable t-norms.'''
-
- @abstractmethod
- def norm(self, id, init_params):
- '''Elementwise t-norm of x and y.'''
- pass
-
- @abstractmethod
- def norms(self, id, init_params):
- '''T-norm computed for tensor x along axis.'''
- pass
-

- class ProductTNorm(TNorm):
- '''Product t-norm given by the expression (x, y) -> x * y.'''
-
- @staticmethod
- def _jax_wrapped_calc_and_approx(x, y, params):
- return x * y, params
+ class DeterminizedGeometric(JaxRDDLCompilerWithGrad):
+
+ def __init__(self, *args, **kwargs) -> None:
+ super(DeterminizedGeometric, self).__init__(*args, **kwargs)
+
+ def get_kwargs(self) -> Dict[str, Any]:
+ return super().get_kwargs()
+
+ def _jax_geometric(self, expr, aux):
+ ERR = JaxRDDLCompilerWithGrad.ERROR_CODES['INVALID_PARAM_GEOMETRIC']
+ JaxRDDLCompilerWithGrad._check_num_args(expr, 1)
+ arg_prob, = expr.args

- def norm(self, id, init_params):
- return self._jax_wrapped_calc_and_approx
-
- @staticmethod
- def _jax_wrapped_calc_forall_approx(x, axis, params):
- return jnp.prod(x, axis=axis), params
+ # if prob is non-fluent, always use the exact operation
+ if not self.traced.cached_is_fluent(arg_prob):
+ return super()._jax_geometric(expr, aux)
+
+ aux['overriden'][expr.id] = __class__.__name__
+
+ jax_prob = self._jax(arg_prob, aux)

- def norms(self, id, init_params):
- return self._jax_wrapped_calc_forall_approx
+ def _jax_wrapped_distribution_geometric_determinized(fls, nfls, params, key):
+ prob, key, err, params = jax_prob(fls, nfls, params, key)
+ sample = 1. / prob
+ out_of_bounds = jnp.logical_not(jnp.all(jnp.logical_and(prob >= 0, prob <= 1)))
+ err = err | (out_of_bounds * ERR)
+ return sample, key, err, params
+ return _jax_wrapped_distribution_geometric_determinized

- def __str__(self) -> str:
- return 'Product t-norm'
-

- class GodelTNorm(TNorm):
- '''Godel t-norm given by the expression (x, y) -> min(x, y).'''
-
- @staticmethod
- def _jax_wrapped_calc_and_approx(x, y, params):
- return jnp.minimum(x, y), params
-
- def norm(self, id, init_params):
- return self._jax_wrapped_calc_and_approx
-
- @staticmethod
- def _jax_wrapped_calc_forall_approx(x, axis, params):
- return jnp.min(x, axis=axis), params
+ # ===============================================================================
+ # distribution relaxations - Bernoulli
+ # ===============================================================================
+
+ class ReparameterizedSigmoidBernoulli(JaxRDDLCompilerWithGrad):
+
+ def __init__(self, *args, bernoulli_sigmoid_weight: float=10., **kwargs) -> None:
+ super(ReparameterizedSigmoidBernoulli, self).__init__(*args, **kwargs)
+ self.bernoulli_sigmoid_weight = float(bernoulli_sigmoid_weight)

- def norms(self, id, init_params):
- return self._jax_wrapped_calc_forall_approx
-
- def __str__(self) -> str:
- return 'Godel t-norm'
-
+ def get_kwargs(self) -> Dict[str, Any]:
+ kwargs = super().get_kwargs()
+ kwargs['bernoulli_sigmoid_weight'] = self.bernoulli_sigmoid_weight
+ return kwargs
+
+ def _jax_bernoulli(self, expr, aux):
+ ERR = JaxRDDLCompilerWithGrad.ERROR_CODES['INVALID_PARAM_BERNOULLI']
+ JaxRDDLCompilerWithGrad._check_num_args(expr, 1)
+ arg_prob, = expr.args
+
+ # if prob is non-fluent, always use the exact operation
+ if not self.traced.cached_is_fluent(arg_prob):
+ return super()._jax_bernoulli(expr, aux)
+
+ id_ = expr.id
+ aux['params'][id_] = self.bernoulli_sigmoid_weight
+ aux['overriden'][id_] = __class__.__name__

- class LukasiewiczTNorm(TNorm):
- '''Lukasiewicz t-norm given by the expression (x, y) -> max(x + y - 1, 0).'''
-
- @staticmethod
- def _jax_wrapped_calc_and_approx(x, y, params):
- land = jax.nn.relu(x + y - 1.0)
- return land, params
+ jax_prob = self._jax(arg_prob, aux)

- def norm(self, id, init_params):
- return self._jax_wrapped_calc_and_approx
-
- @staticmethod
- def _jax_wrapped_calc_forall_approx(x, axis, params):
- forall = jax.nn.relu(jnp.sum(x - 1.0, axis=axis) + 1.0)
- return forall, params
+ def _jax_wrapped_distribution_bernoulli_reparam(fls, nfls, params, key):
+ prob, key, err, params = jax_prob(fls, nfls, params, key)
+ key, subkey = random.split(key)
+ U = random.uniform(key=subkey, shape=jnp.shape(prob), dtype=self.REAL)
+ sample = stable_sigmoid(params[id_] * (prob - U))
+ out_of_bounds = jnp.logical_not(jnp.all(jnp.logical_and(prob >= 0, prob <= 1)))
+ err = err | (out_of_bounds * ERR)
+ return sample, key, err, params
+ return _jax_wrapped_distribution_bernoulli_reparam
+
+
+ class GumbelSoftmaxBernoulli(JaxRDDLCompilerWithGrad):
+
+ def __init__(self, *args,
+ bernoulli_softmax_weight: float=10.,
+ bernoulli_eps: float=1e-14, **kwargs) -> None:
+ super(GumbelSoftmaxBernoulli, self).__init__(*args, **kwargs)
+ self.bernoulli_softmax_weight = float(bernoulli_softmax_weight)
+ self.bernoulli_eps = float(bernoulli_eps)
+
+ def get_kwargs(self) -> Dict[str, Any]:
+ kwargs = super().get_kwargs()
+ kwargs['bernoulli_softmax_weight'] = self.bernoulli_softmax_weight
+ kwargs['bernoulli_eps'] = self.bernoulli_eps
+ return kwargs
+
+ def _jax_bernoulli(self, expr, aux):
+ ERR = JaxRDDLCompilerWithGrad.ERROR_CODES['INVALID_PARAM_BERNOULLI']
+ JaxRDDLCompilerWithGrad._check_num_args(expr, 1)
+ arg_prob, = expr.args
+
+ # if prob is non-fluent, always use the exact operation
+ if not self.traced.cached_is_fluent(arg_prob):
+ return super()._jax_bernoulli(expr, aux)

- def norms(self, id, init_params):
- return self._jax_wrapped_calc_forall_approx
+ id_ = expr.id
+ aux['params'][id_] = (self.bernoulli_softmax_weight, self.bernoulli_eps)
+ aux['overriden'][id_] = __class__.__name__
+
+ jax_prob = self._jax(arg_prob, aux)
+
+ def _jax_wrapped_distribution_bernoulli_gumbel_softmax(fls, nfls, params, key):
+ w, eps = params[id_]
+ prob, key, err, params = jax_prob(fls, nfls, params, key)
+ probs = jnp.stack([1. - prob, prob], axis=-1)
+ key, subkey = random.split(key)
+ g = random.gumbel(key=subkey, shape=jnp.shape(probs), dtype=self.REAL)
+ sample = SoftmaxArgmax.soft_argmax(g + jnp.log(probs + eps), w=w, axes=-1)
+ out_of_bounds = jnp.logical_not(jnp.all(jnp.logical_and(prob >= 0, prob <= 1)))
+ err = err | (out_of_bounds * ERR)
+ return sample, key, err, params
+ return _jax_wrapped_distribution_bernoulli_gumbel_softmax

- def __str__(self) -> str:
- return 'Lukasiewicz t-norm'

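GumbelSoftmaxBernoulli applies the Gumbel-softmax trick to the two-class case: perturb the log-probabilities with Gumbel noise, then take the soft argmax over the two outcomes. A compact standalone version (illustrative, not from the package):

    import jax
    import jax.numpy as jnp

    def gumbel_softmax_bernoulli(key, p, weight=10.0, eps=1e-14):
        probs = jnp.stack([1.0 - p, p], axis=-1)
        g = jax.random.gumbel(key, jnp.shape(probs))
        weights = jax.nn.softmax(weight * (g + jnp.log(probs + eps)), axis=-1)
        return jnp.sum(jnp.arange(2) * weights, axis=-1)   # soft sample in [0, 1]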
1096
+ class DeterminizedBernoulli(JaxRDDLCompilerWithGrad):
313
1097
 
314
- class YagerTNorm(TNorm):
315
- '''Yager t-norm given by the expression
316
- (x, y) -> max(1 - ((1 - x)^p + (1 - y)^p)^(1/p)).'''
317
-
318
- def __init__(self, p: float=2.0) -> None:
319
- self.p = float(p)
320
-
321
- def norm(self, id, init_params):
322
- id_ = str(id)
323
- init_params[id_] = self.p
324
- def _jax_wrapped_calc_and_approx(x, y, params):
325
- base = jax.nn.relu(1.0 - jnp.stack([x, y], axis=0))
326
- arg = jnp.linalg.norm(base, ord=params[id_], axis=0)
327
- land = jax.nn.relu(1.0 - arg)
328
- return land, params
329
- return _jax_wrapped_calc_and_approx
330
-
331
- def norms(self, id, init_params):
332
- id_ = str(id)
333
- init_params[id_] = self.p
334
- def _jax_wrapped_calc_forall_approx(x, axis, params):
335
- arg = jax.nn.relu(1.0 - x)
336
- for ax in sorted(axis, reverse=True):
337
- arg = jnp.linalg.norm(arg, ord=params[id_], axis=ax)
338
- forall = jax.nn.relu(1.0 - arg)
339
- return forall, params
340
- return _jax_wrapped_calc_forall_approx
341
-
342
- def __str__(self) -> str:
343
- return f'Yager({self.p}) t-norm'
344
-
1098
+ def __init__(self, *args, **kwargs) -> None:
1099
+ super(DeterminizedBernoulli, self).__init__(*args, **kwargs)
345
1100
 
346
- # ===========================================================================
347
- # RANDOM SAMPLING
348
- # - abstract sampler
349
- # - Gumbel-softmax sampler
350
- # - determinization
351
- #
352
- # ===========================================================================
1101
+ def get_kwargs(self) -> Dict[str, Any]:
1102
+ return super().get_kwargs()
353
1103
 
354
- class RandomSampling(metaclass=ABCMeta):
355
- '''Describes how non-reparameterizable random variables are sampled.'''
356
-
357
- @abstractmethod
358
- def discrete(self, id, init_params, logic):
359
- pass
360
-
361
- @abstractmethod
362
- def poisson(self, id, init_params, logic):
363
- pass
364
-
365
- @abstractmethod
366
- def binomial(self, id, init_params, logic):
367
- pass
368
-
369
- @abstractmethod
370
- def negative_binomial(self, id, init_params, logic):
371
- pass
372
-
373
- @abstractmethod
374
- def geometric(self, id, init_params, logic):
375
- pass
376
-
377
- @abstractmethod
378
- def bernoulli(self, id, init_params, logic):
379
- pass
380
-
381
- def __str__(self) -> str:
382
- return 'RandomSampling'
1104
+ def _jax_bernoulli(self, expr, aux):
1105
+ ERR = JaxRDDLCompilerWithGrad.ERROR_CODES['INVALID_PARAM_BERNOULLI']
1106
+ JaxRDDLCompilerWithGrad._check_num_args(expr, 1)
1107
+ arg_prob, = expr.args
1108
+
1109
+ # if prob is non-fluent, always use the exact operation
1110
+ if not self.traced.cached_is_fluent(arg_prob):
1111
+ return super()._jax_bernoulli(expr, aux)
1112
+
1113
+ aux['overriden'][expr.id] = __class__.__name__
383
1114
 
1115
+ jax_prob = self._jax(arg_prob, aux)
1116
+
1117
+ def _jax_wrapped_distribution_bernoulli_determinized(fls, nfls, params, key):
1118
+ prob, key, err, params = jax_prob(fls, nfls, params, key)
1119
+ sample = prob
1120
+ out_of_bounds = jnp.logical_not(jnp.all(jnp.logical_and(prob >= 0, prob <= 1)))
1121
+ err = err | (out_of_bounds * ERR)
1122
+ return sample, key, err, params
1123
+ return _jax_wrapped_distribution_bernoulli_determinized
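A one-line way to see what the determinization above buys the planner (illustrative, not package code): the relaxed sample is simply the mean of the distribution, so it is differentiable in the success probability.

import jax
import jax.numpy as jnp

def determinized_bernoulli(prob):
    return prob  # E[Bernoulli(prob)]

prob = jnp.float32(0.7)
print(determinized_bernoulli(prob))            # 0.7
print(jax.grad(determinized_bernoulli)(prob))  # 1.0, gradient flows through prob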
1124
+
1125
+
1126
+ # ===============================================================================
1127
+ # distribution relaxations - Discrete
1128
+ # ===============================================================================
1129
+
1130
+ # https://arxiv.org/pdf/1611.01144
1131
+ class GumbelSoftmaxDiscrete(JaxRDDLCompilerWithGrad):
1132
+
1133
+ def __init__(self, *args,
1134
+ discrete_softmax_weight: float=10.,
1135
+ discrete_eps: float=1e-14, **kwargs) -> None:
1136
+ super(GumbelSoftmaxDiscrete, self).__init__(*args, **kwargs)
1137
+ self.discrete_softmax_weight = float(discrete_softmax_weight)
1138
+ self.discrete_eps = float(discrete_eps)
1139
+
1140
+ def get_kwargs(self) -> Dict[str, Any]:
1141
+ kwargs = super().get_kwargs()
1142
+ kwargs['discrete_softmax_weight'] = self.discrete_softmax_weight
1143
+ kwargs['discrete_eps'] = self.discrete_eps
1144
+ return kwargs
1145
+
1146
+ def _jax_discrete(self, expr, aux, unnorm):
1147
+
1148
+ # if all probabilities are non-fluent, then always sample exact
1149
+ ordered_args = self.traced.cached_sim_info(expr)
1150
+ if not any(self.traced.cached_is_fluent(arg) for arg in ordered_args):
1151
+ return super()._jax_discrete(expr, aux)
1152
+
1153
+ id_ = expr.id
1154
+ aux['params'][id_] = (self.discrete_softmax_weight, self.discrete_eps)
1155
+ aux['overriden'][id_] = __class__.__name__
384
1156
 
385
- class SoftRandomSampling(RandomSampling):
386
- '''Random sampling of discrete variables using Gumbel-softmax trick.'''
387
-
388
- def __init__(self, poisson_max_bins: int=100,
389
- poisson_min_cdf: float=0.999,
390
- poisson_exp_sampling: bool=True,
391
- binomial_max_bins: int=100,
392
- bernoulli_gumbel_softmax: bool=False) -> None:
393
- '''Creates a new instance of soft random sampling.
394
-
395
- :param poisson_max_bins: maximum bins to use for Poisson distribution relaxation
396
- :param poisson_min_cdf: minimum cdf value of Poisson within truncated region
397
- in order to use Poisson relaxation
398
- :param poisson_exp_sampling: whether to use Poisson process sampling method
399
- instead of truncated Gumbel-Softmax
400
- :param binomial_max_bins: maximum bins to use for Binomial distribution relaxation
401
- :param bernoulli_gumbel_softmax: whether to use Gumbel-Softmax to approximate
402
- Bernoulli samples, or the standard uniform reparameterization instead
403
- '''
404
- self.poisson_bins = poisson_max_bins
405
- self.poisson_min_cdf = poisson_min_cdf
406
- self.poisson_exp_method = poisson_exp_sampling
407
- self.binomial_bins = binomial_max_bins
408
- self.bernoulli_gumbel_softmax = bernoulli_gumbel_softmax
409
-
410
- # https://arxiv.org/pdf/1611.01144
411
- def discrete(self, id, init_params, logic):
412
- argmax_approx = logic.argmax(id, init_params)
413
- def _jax_wrapped_calc_discrete_gumbel_softmax(key, prob, params):
414
- Gumbel01 = random.gumbel(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
415
- sample = Gumbel01 + jnp.log(prob + logic.eps)
416
- return argmax_approx(sample, axis=-1, params=params)
417
- return _jax_wrapped_calc_discrete_gumbel_softmax
418
-
419
- def _poisson_gumbel_softmax(self, id, init_params, logic):
420
- argmax_approx = logic.argmax(id, init_params)
421
- def _jax_wrapped_calc_poisson_gumbel_softmax(key, rate, params):
422
- ks = jnp.arange(self.poisson_bins)[(jnp.newaxis,) * jnp.ndim(rate) + (...,)]
423
- rate = rate[..., jnp.newaxis]
424
- log_prob = ks * jnp.log(rate + logic.eps) - rate - scipy.special.gammaln(ks + 1)
425
- Gumbel01 = random.gumbel(key=key, shape=jnp.shape(log_prob), dtype=logic.REAL)
426
- sample = Gumbel01 + log_prob
427
- return argmax_approx(sample, axis=-1, params=params)
428
- return _jax_wrapped_calc_poisson_gumbel_softmax
429
-
430
- # https://arxiv.org/abs/2405.14473
431
- def _poisson_exponential(self, id, init_params, logic):
432
- less_approx = logic.less(id, init_params)
433
- def _jax_wrapped_calc_poisson_exponential(key, rate, params):
434
- Exp1 = random.exponential(
435
- key=key, shape=(self.poisson_bins,) + jnp.shape(rate), dtype=logic.REAL)
436
- delta_t = Exp1 / rate[jnp.newaxis, ...]
437
- times = jnp.cumsum(delta_t, axis=0)
438
- indicator, params = less_approx(times, 1.0, params)
439
- sample = jnp.sum(indicator, axis=0)
440
- return sample, params
441
- return _jax_wrapped_calc_poisson_exponential
442
-
443
- # normal approximation to Poisson: Poisson(rate) -> Normal(rate, rate)
444
- def _poisson_normal_approx(self, logic):
445
- def _jax_wrapped_calc_poisson_normal_approx(key, rate, params):
446
- normal = random.normal(key=key, shape=jnp.shape(rate), dtype=logic.REAL)
447
- sample = rate + jnp.sqrt(rate) * normal
448
- return sample, params
449
- return _jax_wrapped_calc_poisson_normal_approx
450
-
451
- def poisson(self, id, init_params, logic):
452
- if self.poisson_exp_method:
453
- _jax_wrapped_calc_poisson_diff = self._poisson_exponential(
454
- id, init_params, logic)
455
- else:
456
- _jax_wrapped_calc_poisson_diff = self._poisson_gumbel_softmax(
457
- id, init_params, logic)
458
- _jax_wrapped_calc_poisson_normal = self._poisson_normal_approx(logic)
459
-
460
- # for small rate use the Poisson process or gumbel-softmax reparameterization
461
- # for large rate use the normal approximation
462
- def _jax_wrapped_calc_poisson_approx(key, rate, params):
463
- if self.poisson_bins > 0:
464
- cuml_prob = scipy.stats.poisson.cdf(self.poisson_bins, rate)
465
- small_rate = jax.lax.stop_gradient(cuml_prob >= self.poisson_min_cdf)
466
- small_sample, params = _jax_wrapped_calc_poisson_diff(key, rate, params)
467
- large_sample, params = _jax_wrapped_calc_poisson_normal(key, rate, params)
468
- sample = jnp.where(small_rate, small_sample, large_sample)
469
- return sample, params
470
- else:
471
- return _jax_wrapped_calc_poisson_normal(key, rate, params)
472
- return _jax_wrapped_calc_poisson_approx
473
-
474
- # normal approximation to Binomial: Bin(n, p) -> Normal(np, np(1-p))
475
- def _binomial_normal_approx(self, logic):
476
- def _jax_wrapped_calc_binomial_normal_approx(key, trials, prob, params):
477
- normal = random.normal(key=key, shape=jnp.shape(trials), dtype=logic.REAL)
478
- mean = trials * prob
479
- std = jnp.sqrt(trials * prob * (1.0 - prob))
480
- sample = mean + std * normal
481
- return sample, params
482
- return _jax_wrapped_calc_binomial_normal_approx
483
-
484
- def _binomial_gumbel_softmax(self, id, init_params, logic):
485
- argmax_approx = logic.argmax(id, init_params)
486
- def _jax_wrapped_calc_binomial_gumbel_softmax(key, trials, prob, params):
487
- ks = jnp.arange(self.binomial_bins)[(jnp.newaxis,) * jnp.ndim(trials) + (...,)]
488
- trials = trials[..., jnp.newaxis]
489
- prob = prob[..., jnp.newaxis]
490
- in_support = ks <= trials
491
- ks = jnp.minimum(ks, trials)
492
- log_prob = ((scipy.special.gammaln(trials + 1) -
493
- scipy.special.gammaln(ks + 1) -
494
- scipy.special.gammaln(trials - ks + 1)) +
495
- ks * jnp.log(prob + logic.eps) +
496
- (trials - ks) * jnp.log1p(-prob + logic.eps))
497
- log_prob = jnp.where(in_support, log_prob, jnp.log(logic.eps))
498
- Gumbel01 = random.gumbel(key=key, shape=jnp.shape(log_prob), dtype=logic.REAL)
499
- sample = Gumbel01 + log_prob
500
- return argmax_approx(sample, axis=-1, params=params)
501
- return _jax_wrapped_calc_binomial_gumbel_softmax
502
-
503
- def binomial(self, id, init_params, logic):
504
- _jax_wrapped_calc_binomial_normal = self._binomial_normal_approx(logic)
505
- _jax_wrapped_calc_binomial_gs = self._binomial_gumbel_softmax(id, init_params, logic)
506
-
507
- # for small trials use the Bernoulli relaxation
508
- # for large trials use the normal approximation
509
- def _jax_wrapped_calc_binomial_approx(key, trials, prob, params):
510
- small_trials = jax.lax.stop_gradient(trials < self.binomial_bins)
511
- small_sample, params = _jax_wrapped_calc_binomial_gs(key, trials, prob, params)
512
- large_sample, params = _jax_wrapped_calc_binomial_normal(key, trials, prob, params)
513
- sample = jnp.where(small_trials, small_sample, large_sample)
514
- return sample, params
515
- return _jax_wrapped_calc_binomial_approx
516
-
517
- # https://en.wikipedia.org/wiki/Negative_binomial_distribution#Gamma%E2%80%93Poisson_mixture
518
- def negative_binomial(self, id, init_params, logic):
519
- poisson_approx = self.poisson(id, init_params, logic)
520
- def _jax_wrapped_calc_negative_binomial_approx(key, trials, prob, params):
1157
+ jax_probs = [self._jax(arg, aux) for arg in ordered_args]
1158
+ prob_fn = self._jax_discrete_prob(jax_probs, unnorm)
1159
+
1160
+ def _jax_wrapped_distribution_discrete_gumbel_softmax(fls, nfls, params, key):
1161
+ w, eps = params[id_]
1162
+ prob, key, err, params = prob_fn(fls, nfls, params, key)
521
1163
  key, subkey = random.split(key)
522
- trials = jnp.asarray(trials, dtype=logic.REAL)
523
- Gamma = random.gamma(key=key, a=trials, dtype=logic.REAL)
524
- scale = (1.0 - prob) / prob
525
- poisson_rate = scale * Gamma
526
- return poisson_approx(subkey, poisson_rate, params)
527
- return _jax_wrapped_calc_negative_binomial_approx
528
-
529
- def geometric(self, id, init_params, logic):
530
- approx_floor = logic.floor(id, init_params)
531
- def _jax_wrapped_calc_geometric_approx(key, prob, params):
532
- U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
533
- floor, params = approx_floor(
534
- jnp.log1p(-U) / jnp.log1p(-prob + logic.eps), params)
535
- sample = floor + 1
536
- return sample, params
537
- return _jax_wrapped_calc_geometric_approx
538
-
539
- def _bernoulli_uniform(self, id, init_params, logic):
540
- less_approx = logic.less(id, init_params)
541
- def _jax_wrapped_calc_bernoulli_uniform(key, prob, params):
542
- U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
543
- return less_approx(U, prob, params)
544
- return _jax_wrapped_calc_bernoulli_uniform
545
-
546
- def _bernoulli_gumbel_softmax(self, id, init_params, logic):
547
- discrete_approx = self.discrete(id, init_params, logic)
548
- def _jax_wrapped_calc_bernoulli_gumbel_softmax(key, prob, params):
549
- prob = jnp.stack([1.0 - prob, prob], axis=-1)
550
- return discrete_approx(key, prob, params)
551
- return _jax_wrapped_calc_bernoulli_gumbel_softmax
552
-
553
- def bernoulli(self, id, init_params, logic):
554
- if self.bernoulli_gumbel_softmax:
555
- return self._bernoulli_gumbel_softmax(id, init_params, logic)
556
- else:
557
- return self._bernoulli_uniform(id, init_params, logic)
558
-
559
- def __str__(self) -> str:
560
- return 'SoftRandomSampling'
561
-
1164
+ g = random.gumbel(key=subkey, shape=jnp.shape(prob), dtype=self.REAL)
1165
+ sample = SoftmaxArgmax.soft_argmax(g + jnp.log(prob + eps), w=w, axes=-1)
1166
+ err = JaxRDDLCompilerWithGrad._jax_update_discrete_oob_error(err, prob)
1167
+ return sample, key, err, params
1168
+ return _jax_wrapped_distribution_discrete_gumbel_softmax
1169
+
1170
+ def _jax_discrete_pvar(self, expr, aux, unnorm):
1171
+ JaxRDDLCompilerWithGrad._check_num_args(expr, 2)
1172
+ _, args = expr.args
1173
+ arg, = args
1174
+
1175
+ # if all probabilities are non-fluent, then always sample exact
1176
+ if not self.traced.cached_is_fluent(arg):
1177
+ return super()._jax_discrete_pvar(expr, aux)
1178
+
1179
+ id_ = expr.id
1180
+ aux['params'][id_] = (self.discrete_softmax_weight, self.discrete_eps)
1181
+ aux['overriden'][id_] = __class__.__name__
1182
+
1183
+ jax_probs = self._jax(arg, aux)
1184
+ prob_fn = self._jax_discrete_pvar_prob(jax_probs, unnorm)
562
1185
 
563
- class Determinization(RandomSampling):
564
- '''Random sampling of variables using their deterministic mean estimate.'''
1186
+ def _jax_wrapped_distribution_discrete_pvar_gumbel_softmax(fls, nfls, params, key):
1187
+ w, eps = params[id_]
1188
+ prob, key, err, params = prob_fn(fls, nfls, params, key)
1189
+ key, subkey = random.split(key)
1190
+ g = random.gumbel(key=subkey, shape=jnp.shape(prob), dtype=self.REAL)
1191
+ sample = SoftmaxArgmax.soft_argmax(g + jnp.log(prob + eps), w=w, axes=-1)
1192
+ err = JaxRDDLCompilerWithGrad._jax_update_discrete_oob_error(err, prob)
1193
+ return sample, key, err, params
1194
+ return _jax_wrapped_distribution_discrete_pvar_gumbel_softmax
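A standalone sketch of the Gumbel-(soft)max trick that both _jax_discrete variants above rely on, following https://arxiv.org/pdf/1611.01144 (helper names are illustrative, not the package API):

import jax
import jax.numpy as jnp
from jax import random

def gumbel_softmax_categorical(key, probs, w=10.0, eps=1e-14):
    # soft sample: replace the argmax of the Gumbel-max trick by a
    # softmax-weighted average of the category indices
    scores = random.gumbel(key, probs.shape) + jnp.log(probs + eps)
    weights = jax.nn.softmax(w * scores, axis=-1)
    return jnp.sum(weights * jnp.arange(probs.shape[-1]), axis=-1)

probs = jnp.array([0.2, 0.5, 0.3])
keys = random.split(random.PRNGKey(42), 5000)

# hard Gumbel-max draws recover the target probabilities ...
hard = jax.vmap(lambda k: jnp.argmax(random.gumbel(k, probs.shape) + jnp.log(probs)))(keys)
print(jnp.bincount(hard, length=3) / len(keys))  # approximately [0.2, 0.5, 0.3]

# ... while the relaxed draws are differentiable and concentrate near 0, 1 and 2
soft = jax.vmap(lambda k: gumbel_softmax_categorical(k, probs))(keys)
print(jnp.round(soft[:10]))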
565
1195
 
566
- @staticmethod
567
- def _jax_wrapped_calc_discrete_determinized(key, prob, params):
568
- literals = enumerate_literals(jnp.shape(prob), axis=-1)
569
- sample = jnp.sum(literals * prob, axis=-1)
570
- return sample, params
571
-
572
- def discrete(self, id, init_params, logic):
573
- return self._jax_wrapped_calc_discrete_determinized
574
-
575
- @staticmethod
576
- def _jax_wrapped_calc_poisson_determinized(key, rate, params):
577
- return rate, params
578
1196
 
579
- def poisson(self, id, init_params, logic):
580
- return self._jax_wrapped_calc_poisson_determinized
581
-
582
- @staticmethod
583
- def _jax_wrapped_calc_binomial_determinized(key, trials, prob, params):
584
- sample = trials * prob
585
- return sample, params
586
-
587
- def binomial(self, id, init_params, logic):
588
- return self._jax_wrapped_calc_binomial_determinized
1197
+ class DeterminizedDiscrete(JaxRDDLCompilerWithGrad):
589
1198
 
590
- @staticmethod
591
- def _jax_wrapped_calc_negative_binomial_determinized(key, trials, prob, params):
592
- sample = trials * ((1.0 / prob) - 1.0)
593
- return sample, params
1199
+ def __init__(self, *args, **kwargs) -> None:
1200
+ super(DeterminizedDiscrete, self).__init__(*args, **kwargs)
594
1201
 
595
- def negative_binomial(self, id, init_params, logic):
596
- return self._jax_wrapped_calc_negative_binomial_determinized
597
-
598
- @staticmethod
599
- def _jax_wrapped_calc_geometric_determinized(key, prob, params):
600
- sample = 1.0 / prob
601
- return sample, params
602
-
603
- def geometric(self, id, init_params, logic):
604
- return self._jax_wrapped_calc_geometric_determinized
605
-
606
- @staticmethod
607
- def _jax_wrapped_calc_bernoulli_determinized(key, prob, params):
608
- sample = prob
609
- return sample, params
610
-
611
- def bernoulli(self, id, init_params, logic):
612
- return self._jax_wrapped_calc_bernoulli_determinized
1202
+ def get_kwargs(self) -> Dict[str, Any]:
1203
+ return super().get_kwargs()
613
1204
 
614
- def __str__(self) -> str:
615
- return 'Deterministic'
616
-
1205
+ def _jax_discrete(self, expr, aux, unnorm):
1206
+
1207
+ # if all probabilities are non-fluent, then always sample exact
1208
+ ordered_args = self.traced.cached_sim_info(expr)
1209
+ if not any(self.traced.cached_is_fluent(arg) for arg in ordered_args):
1210
+ return super()._jax_discrete(expr, aux)
1211
+
1212
+ aux['overriden'][expr.id] = __class__.__name__
1213
+
1214
+ jax_probs = [self._jax(arg, aux) for arg in ordered_args]
1215
+ prob_fn = self._jax_discrete_prob(jax_probs, unnorm)
1216
+
1217
+ def _jax_wrapped_distribution_discrete_determinized(fls, nfls, params, key):
1218
+ prob, key, err, params = prob_fn(fls, nfls, params, key)
1219
+ literals = enumerate_literals(jnp.shape(prob), axis=-1)
1220
+ sample = jnp.sum(literals * prob, axis=-1)
1221
+ err = JaxRDDLCompilerWithGrad._jax_update_discrete_oob_error(err, prob)
1222
+ return sample, key, err, params
1223
+ return _jax_wrapped_distribution_discrete_determinized
1224
+
1225
+ def _jax_discrete_pvar(self, expr, aux, unnorm):
1226
+ JaxRDDLCompilerWithGrad._check_num_args(expr, 2)
1227
+ _, args = expr.args
1228
+ arg, = args
1229
+
1230
+ # if all probabilities are non-fluent, then always sample exact
1231
+ if not self.traced.cached_is_fluent(arg):
1232
+ return super()._jax_discrete_pvar(expr, aux)
1233
+
1234
+ aux['overriden'][expr.id] = __class__.__name__
1235
+
1236
+ jax_probs = self._jax(arg, aux)
1237
+ prob_fn = self._jax_discrete_pvar_prob(jax_probs, unnorm)
1238
+
1239
+ def _jax_wrapped_distribution_discrete_pvar_determinized(fls, nfls, params, key):
1240
+ prob, key, err, params = prob_fn(fls, nfls, params, key)
1241
+ literals = enumerate_literals(jnp.shape(prob), axis=-1)
1242
+ sample = jnp.sum(literals * prob, axis=-1)
1243
+ err = JaxRDDLCompilerWithGrad._jax_update_discrete_oob_error(err, prob)
1244
+ return sample, key, err, params
1245
+ return _jax_wrapped_distribution_discrete_pvar_determinized
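The determinized Discrete above reduces to the expected literal, i.e. sum_i i * p_i. A tiny illustration (not package code):

import jax.numpy as jnp

def determinized_discrete(probs):
    literals = jnp.arange(probs.shape[-1])
    return jnp.sum(literals * probs, axis=-1)

print(determinized_discrete(jnp.array([0.2, 0.5, 0.3])))  # 0*0.2 + 1*0.5 + 2*0.3 = 1.1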
1246
+
1247
+
1248
+ # ===============================================================================
1249
+ # distribution relaxations - Binomial
1250
+ # ===============================================================================
1251
+
1252
+ class GumbelSoftmaxBinomial(JaxRDDLCompilerWithGrad):
1253
+
1254
+ def __init__(self, *args,
1255
+ binomial_nbins: int=100,
1256
+ binomial_softmax_weight: float=10.,
1257
+ binomial_eps: float=1e-14, **kwargs) -> None:
1258
+ super(GumbelSoftmaxBinomial, self).__init__(*args, **kwargs)
1259
+ self.binomial_nbins = binomial_nbins
1260
+ self.binomial_softmax_weight = float(binomial_softmax_weight)
1261
+ self.binomial_eps = float(binomial_eps)
1262
+
1263
+ def get_kwargs(self) -> Dict[str, Any]:
1264
+ kwargs = super().get_kwargs()
1265
+ kwargs['binomial_nbins'] = self.binomial_nbins
1266
+ kwargs['binomial_softmax_weight'] = self.binomial_softmax_weight
1267
+ kwargs['binomial_eps'] = self.binomial_eps
1268
+ return kwargs
617
1269
 
618
- # ===========================================================================
619
- # CONTROL FLOW
620
- # - soft flow
621
- #
622
- # ===========================================================================
1270
+ # normal approximation to Binomial: Bin(n, p) -> Normal(np, np(1-p))
1271
+ @staticmethod
1272
+ def normal_approx_to_binomial(key: random.PRNGKey,
1273
+ trials: jnp.ndarray, prob: jnp.ndarray) -> jnp.ndarray:
1274
+ normal = random.normal(key=key, shape=jnp.shape(trials), dtype=prob.dtype)
1275
+ mean = trials * prob
1276
+ std = jnp.sqrt(trials * jnp.clip(prob * (1.0 - prob), 0.0, 1.0))
1277
+ return mean + std * normal
1278
+
1279
+ def gumbel_softmax_approx_to_binomial(self, key: random.PRNGKey,
1280
+ trials: jnp.ndarray, prob: jnp.ndarray,
1281
+ w: float, eps: float):
1282
+ ks = jnp.arange(self.binomial_nbins)[(jnp.newaxis,) * jnp.ndim(trials) + (...,)]
1283
+ trials = trials[..., jnp.newaxis]
1284
+ prob = prob[..., jnp.newaxis]
1285
+ in_support = ks <= trials
1286
+ ks = jnp.minimum(ks, trials)
1287
+ log_prob = ((scipy.special.gammaln(trials + 1) -
1288
+ scipy.special.gammaln(ks + 1) -
1289
+ scipy.special.gammaln(trials - ks + 1)) +
1290
+ ks * jnp.log(prob + eps) +
1291
+ (trials - ks) * jnp.log1p(-prob + eps))
1292
+ log_prob = jnp.where(in_support, log_prob, jnp.log(eps))
1293
+ g = random.gumbel(key=key, shape=jnp.shape(log_prob), dtype=prob.dtype)
1294
+ return SoftmaxArgmax.soft_argmax(g + log_prob, w=w, axes=-1)
1295
+
1296
+ def _jax_binomial(self, expr, aux):
1297
+ ERR = JaxRDDLCompilerWithGrad.ERROR_CODES['INVALID_PARAM_BINOMIAL']
1298
+ JaxRDDLCompilerWithGrad._check_num_args(expr, 2)
1299
+ arg_trials, arg_prob = expr.args
1300
+
1301
+ # if trials and prob are non-fluent, always use the exact operation
+ # if trials and prob are non-fluent, always use the exact operation
1302
+ if (not self.traced.cached_is_fluent(arg_trials) and
1303
+ not self.traced.cached_is_fluent(arg_prob)):
1304
+ return super()._jax_binomial(expr, aux)
1305
+
1306
+ id_ = expr.id
1307
+ aux['params'][id_] = (self.binomial_softmax_weight, self.binomial_eps)
1308
+ aux['overriden'][id_] = __class__.__name__
1309
+
1310
+ # recursively compile arguments
1311
+ jax_trials = self._jax(arg_trials, aux)
1312
+ jax_prob = self._jax(arg_prob, aux)
623
1313
 
624
- class ControlFlow(metaclass=ABCMeta):
625
- '''A base class for control flow, including if and switch statements.'''
626
-
627
- @abstractmethod
628
- def if_then_else(self, id, init_params):
629
- pass
630
-
631
- @abstractmethod
632
- def switch(self, id, init_params):
633
- pass
1314
+ def _jax_wrapped_distribution_binomial_gumbel_softmax(fls, nfls, params, key):
1315
+ trials, key, err2, params = jax_trials(fls, nfls, params, key)
1316
+ prob, key, err1, params = jax_prob(fls, nfls, params, key)
1317
+ key, subkey = random.split(key)
1318
+ trials = jnp.asarray(trials, dtype=self.REAL)
1319
+ prob = jnp.asarray(prob, dtype=self.REAL)
634
1320
 
1321
+ # use the gumbel-softmax trick for small population size
1322
+ # use the normal approximation for large population size
1323
+ sample = jnp.where(
1324
+ jax.lax.stop_gradient(trials < self.binomial_nbins),
1325
+ self.gumbel_softmax_approx_to_binomial(subkey, trials, prob, *params[id_]),
1326
+ self.normal_approx_to_binomial(subkey, trials, prob)
1327
+ )
635
1328
 
636
- class SoftControlFlow(ControlFlow):
637
- '''Soft control flow using a probabilistic interpretation.'''
1329
+ out_of_bounds = jnp.logical_not(jnp.all(
1330
+ jnp.logical_and(jnp.logical_and(prob >= 0, prob <= 1), trials >= 0)))
1331
+ err = err1 | err2 | (out_of_bounds * ERR)
1332
+ return sample, key, err, params
1333
+ return _jax_wrapped_distribution_binomial_gumbel_softmax
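A standalone sketch of the two-branch Binomial relaxation compiled above: a truncated Gumbel-Softmax over k = 0..nbins-1 when the trial count is small, and the normal approximation Bin(n, p) ~ N(np, np(1-p)) otherwise (not package code; the helper below collapses both branches into one scalar-valued function):

import jax
import jax.numpy as jnp
from jax import random
from jax.scipy.special import gammaln

def relaxed_binomial(key, n, p, nbins=100, w=10.0, eps=1e-14):
    n = jnp.asarray(n, jnp.float32)
    k_gs, k_norm = random.split(key)

    # Gumbel-Softmax branch over the truncated support k = 0 .. nbins-1
    ks = jnp.arange(nbins, dtype=jnp.float32)
    in_support = ks <= n
    ks_c = jnp.minimum(ks, n)
    log_pmf = (gammaln(n + 1) - gammaln(ks_c + 1) - gammaln(n - ks_c + 1)
               + ks_c * jnp.log(p + eps) + (n - ks_c) * jnp.log1p(-p + eps))
    log_pmf = jnp.where(in_support, log_pmf, jnp.log(eps))
    weights = jax.nn.softmax(w * (random.gumbel(k_gs, log_pmf.shape) + log_pmf))
    soft_sample = jnp.sum(weights * ks)

    # normal branch for large trial counts
    normal_sample = n * p + jnp.sqrt(n * p * (1.0 - p)) * random.normal(k_norm)

    return jnp.where(n < nbins, soft_sample, normal_sample)

key = random.PRNGKey(1)
print(relaxed_binomial(key, 10, 0.4))   # relaxed draw near Binomial(10, 0.4)
print(relaxed_binomial(key, 500, 0.4))  # normal branch, centred around 200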
638
1334
 
639
- def __init__(self, weight: float=10.0) -> None:
640
- self.weight = float(weight)
641
-
642
- @staticmethod
643
- def _jax_wrapped_calc_if_then_else_soft(c, a, b, params):
644
- sample = c * a + (1.0 - c) * b
645
- return sample, params
646
-
647
- def if_then_else(self, id, init_params):
648
- return self._jax_wrapped_calc_if_then_else_soft
649
-
650
- def switch(self, id, init_params):
651
- id_ = str(id)
652
- init_params[id_] = self.weight
653
- def _jax_wrapped_calc_switch_soft(pred, cases, params):
654
- literals = enumerate_literals(jnp.shape(cases), axis=0)
655
- pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=jnp.shape(cases))
656
- proximity = -jnp.square(pred - literals)
657
- softcase = jax.nn.softmax(params[id_] * proximity, axis=0)
658
- sample = jnp.sum(cases * softcase, axis=0)
659
- return sample, params
660
- return _jax_wrapped_calc_switch_soft
661
-
662
- def __str__(self) -> str:
663
- return f'Soft control flow with weight {self.weight}'
664
-
665
-
666
- # ===========================================================================
667
- # LOGIC
668
- # - exact logic
669
- # - fuzzy logic
670
- #
671
- # ===========================================================================
672
1335
 
1336
+ class DeterminizedBinomial(JaxRDDLCompilerWithGrad):
673
1337
 
674
- class Logic(metaclass=ABCMeta):
675
- '''A base class for representing logic computations in JAX.'''
676
-
677
- def __init__(self, use64bit: bool=False) -> None:
678
- self.set_use64bit(use64bit)
679
-
680
- def summarize_hyperparameters(self) -> str:
681
- return (f'model relaxation:\n'
682
- f' use_64_bit ={self.use64bit}')
683
-
684
- def set_use64bit(self, use64bit: bool) -> None:
685
- '''Toggles whether or not the JAX system will use 64 bit precision.'''
686
- self.use64bit = use64bit
687
- if use64bit:
688
- self.REAL = jnp.float64
689
- self.INT = jnp.int64
690
- jax.config.update('jax_enable_x64', True)
691
- else:
692
- self.REAL = jnp.float32
693
- self.INT = jnp.int32
694
- jax.config.update('jax_enable_x64', False)
695
-
696
- @staticmethod
697
- def wrap_logic(func):
698
- def exact_func(id, init_params):
699
- return func
700
- return exact_func
701
-
702
- def get_operator_dicts(self) -> Dict[str, Union[Callable, Dict[str, Callable]]]:
703
- '''Returns a dictionary of all operators in the current logic.'''
704
- return {
705
- 'negative': self.wrap_logic(ExactLogic.exact_unary_function(jnp.negative)),
706
- 'arithmetic': {
707
- '+': self.wrap_logic(ExactLogic.exact_binary_function(jnp.add)),
708
- '-': self.wrap_logic(ExactLogic.exact_binary_function(jnp.subtract)),
709
- '*': self.wrap_logic(ExactLogic.exact_binary_function(jnp.multiply)),
710
- '/': self.wrap_logic(ExactLogic.exact_binary_function(jnp.divide))
711
- },
712
- 'relational': {
713
- '>=': self.greater_equal,
714
- '<=': self.less_equal,
715
- '<': self.less,
716
- '>': self.greater,
717
- '==': self.equal,
718
- '~=': self.not_equal
719
- },
720
- 'logical_not': self.logical_not,
721
- 'logical': {
722
- '^': self.logical_and,
723
- '&': self.logical_and,
724
- '|': self.logical_or,
725
- '~': self.xor,
726
- '=>': self.implies,
727
- '<=>': self.equiv
728
- },
729
- 'aggregation': {
730
- 'sum': self.wrap_logic(ExactLogic.exact_aggregation(jnp.sum)),
731
- 'avg': self.wrap_logic(ExactLogic.exact_aggregation(jnp.mean)),
732
- 'prod': self.wrap_logic(ExactLogic.exact_aggregation(jnp.prod)),
733
- 'minimum': self.wrap_logic(ExactLogic.exact_aggregation(jnp.min)),
734
- 'maximum': self.wrap_logic(ExactLogic.exact_aggregation(jnp.max)),
735
- 'forall': self.forall,
736
- 'exists': self.exists,
737
- 'argmin': self.argmin,
738
- 'argmax': self.argmax
739
- },
740
- 'unary': {
741
- 'abs': self.wrap_logic(ExactLogic.exact_unary_function(jnp.abs)),
742
- 'sgn': self.sgn,
743
- 'round': self.round,
744
- 'floor': self.floor,
745
- 'ceil': self.ceil,
746
- 'cos': self.wrap_logic(ExactLogic.exact_unary_function(jnp.cos)),
747
- 'sin': self.wrap_logic(ExactLogic.exact_unary_function(jnp.sin)),
748
- 'tan': self.wrap_logic(ExactLogic.exact_unary_function(jnp.tan)),
749
- 'acos': self.wrap_logic(ExactLogic.exact_unary_function(jnp.arccos)),
750
- 'asin': self.wrap_logic(ExactLogic.exact_unary_function(jnp.arcsin)),
751
- 'atan': self.wrap_logic(ExactLogic.exact_unary_function(jnp.arctan)),
752
- 'cosh': self.wrap_logic(ExactLogic.exact_unary_function(jnp.cosh)),
753
- 'sinh': self.wrap_logic(ExactLogic.exact_unary_function(jnp.sinh)),
754
- 'tanh': self.wrap_logic(ExactLogic.exact_unary_function(jnp.tanh)),
755
- 'exp': self.wrap_logic(ExactLogic.exact_unary_function(jnp.exp)),
756
- 'ln': self.wrap_logic(ExactLogic.exact_unary_function(jnp.log)),
757
- 'sqrt': self.sqrt,
758
- 'lngamma': self.wrap_logic(ExactLogic.exact_unary_function(scipy.special.gammaln)),
759
- 'gamma': self.wrap_logic(ExactLogic.exact_unary_function(scipy.special.gamma))
760
- },
761
- 'binary': {
762
- 'div': self.div,
763
- 'mod': self.mod,
764
- 'fmod': self.mod,
765
- 'min': self.wrap_logic(ExactLogic.exact_binary_function(jnp.minimum)),
766
- 'max': self.wrap_logic(ExactLogic.exact_binary_function(jnp.maximum)),
767
- 'pow': self.wrap_logic(ExactLogic.exact_binary_function(jnp.power)),
768
- 'log': self.wrap_logic(ExactLogic.exact_binary_log),
769
- 'hypot': self.wrap_logic(ExactLogic.exact_binary_function(jnp.hypot)),
770
- },
771
- 'control': {
772
- 'if': self.control_if,
773
- 'switch': self.control_switch
774
- },
775
- 'sampling': {
776
- 'Bernoulli': self.bernoulli,
777
- 'Discrete': self.discrete,
778
- 'Poisson': self.poisson,
779
- 'Geometric': self.geometric,
780
- 'Binomial': self.binomial,
781
- 'NegativeBinomial': self.negative_binomial
782
- }
783
- }
784
-
785
- # ===========================================================================
786
- # logical operators
787
- # ===========================================================================
788
-
789
- @abstractmethod
790
- def logical_and(self, id, init_params):
791
- pass
792
-
793
- @abstractmethod
794
- def logical_not(self, id, init_params):
795
- pass
796
-
797
- @abstractmethod
798
- def logical_or(self, id, init_params):
799
- pass
800
-
801
- @abstractmethod
802
- def xor(self, id, init_params):
803
- pass
804
-
805
- @abstractmethod
806
- def implies(self, id, init_params):
807
- pass
808
-
809
- @abstractmethod
810
- def equiv(self, id, init_params):
811
- pass
812
-
813
- @abstractmethod
814
- def forall(self, id, init_params):
815
- pass
816
-
817
- @abstractmethod
818
- def exists(self, id, init_params):
819
- pass
820
-
821
- # ===========================================================================
822
- # comparison operators
823
- # ===========================================================================
824
-
825
- @abstractmethod
826
- def greater_equal(self, id, init_params):
827
- pass
828
-
829
- @abstractmethod
830
- def greater(self, id, init_params):
831
- pass
832
-
833
- @abstractmethod
834
- def less_equal(self, id, init_params):
835
- pass
836
-
837
- @abstractmethod
838
- def less(self, id, init_params):
839
- pass
840
-
841
- @abstractmethod
842
- def equal(self, id, init_params):
843
- pass
844
-
845
- @abstractmethod
846
- def not_equal(self, id, init_params):
847
- pass
848
-
849
- # ===========================================================================
850
- # special functions
851
- # ===========================================================================
852
-
853
- @abstractmethod
854
- def sgn(self, id, init_params):
855
- pass
856
-
857
- @abstractmethod
858
- def floor(self, id, init_params):
859
- pass
860
-
861
- @abstractmethod
862
- def round(self, id, init_params):
863
- pass
864
-
865
- @abstractmethod
866
- def ceil(self, id, init_params):
867
- pass
868
-
869
- @abstractmethod
870
- def div(self, id, init_params):
871
- pass
872
-
873
- @abstractmethod
874
- def mod(self, id, init_params):
875
- pass
876
-
877
- @abstractmethod
878
- def sqrt(self, id, init_params):
879
- pass
880
-
881
- # ===========================================================================
882
- # indexing
883
- # ===========================================================================
884
-
885
- @abstractmethod
886
- def argmax(self, id, init_params):
887
- pass
888
-
889
- @abstractmethod
890
- def argmin(self, id, init_params):
891
- pass
892
-
893
- # ===========================================================================
894
- # control flow
895
- # ===========================================================================
896
-
897
- @abstractmethod
898
- def control_if(self, id, init_params):
899
- pass
900
-
901
- @abstractmethod
902
- def control_switch(self, id, init_params):
903
- pass
904
-
905
- # ===========================================================================
906
- # random variables
907
- # ===========================================================================
908
-
909
- @abstractmethod
910
- def discrete(self, id, init_params):
911
- pass
912
-
913
- @abstractmethod
914
- def bernoulli(self, id, init_params):
915
- pass
916
-
917
- @abstractmethod
918
- def poisson(self, id, init_params):
919
- pass
920
-
921
- @abstractmethod
922
- def geometric(self, id, init_params):
923
- pass
924
-
925
- @abstractmethod
926
- def binomial(self, id, init_params):
927
- pass
1338
+ def __init__(self, *args, **kwargs) -> None:
1339
+ super(DeterminizedBinomial, self).__init__(*args, **kwargs)
928
1340
 
929
- @abstractmethod
930
- def negative_binomial(self, id, init_params):
931
- pass
1341
+ def get_kwargs(self) -> Dict[str, Any]:
1342
+ return super().get_kwargs()
932
1343
 
1344
+ def _jax_binomial(self, expr, aux):
1345
+ ERR = JaxRDDLCompilerWithGrad.ERROR_CODES['INVALID_PARAM_BINOMIAL']
1346
+ JaxRDDLCompilerWithGrad._check_num_args(expr, 2)
1347
+ arg_trials, arg_prob = expr.args
933
1348
 
934
- class ExactLogic(Logic):
935
- '''A class representing exact logic in JAX.'''
936
-
937
- @staticmethod
938
- def exact_unary_function(op):
939
- def _jax_wrapped_calc_unary_function_exact(x, params):
940
- return op(x), params
941
- return _jax_wrapped_calc_unary_function_exact
942
-
943
- @staticmethod
944
- def exact_binary_function(op):
945
- def _jax_wrapped_calc_binary_function_exact(x, y, params):
946
- return op(x, y), params
947
- return _jax_wrapped_calc_binary_function_exact
948
-
949
- @staticmethod
950
- def exact_aggregation(op):
951
- def _jax_wrapped_calc_aggregation_exact(x, axis, params):
952
- return op(x, axis=axis), params
953
- return _jax_wrapped_calc_aggregation_exact
954
-
955
- # ===========================================================================
956
- # logical operators
957
- # ===========================================================================
958
-
959
- def logical_and(self, id, init_params):
960
- return self.exact_binary_function(jnp.logical_and)
961
-
962
- def logical_not(self, id, init_params):
963
- return self.exact_unary_function(jnp.logical_not)
964
-
965
- def logical_or(self, id, init_params):
966
- return self.exact_binary_function(jnp.logical_or)
967
-
968
- def xor(self, id, init_params):
969
- return self.exact_binary_function(jnp.logical_xor)
970
-
971
- @staticmethod
972
- def _jax_wrapped_calc_implies_exact(x, y, params):
973
- return jnp.logical_or(jnp.logical_not(x), y), params
974
-
975
- def implies(self, id, init_params):
976
- return self._jax_wrapped_calc_implies_exact
977
-
978
- def equiv(self, id, init_params):
979
- return self.exact_binary_function(jnp.equal)
980
-
981
- def forall(self, id, init_params):
982
- return self.exact_aggregation(jnp.all)
983
-
984
- def exists(self, id, init_params):
985
- return self.exact_aggregation(jnp.any)
986
-
987
- # ===========================================================================
988
- # comparison operators
989
- # ===========================================================================
990
-
991
- def greater_equal(self, id, init_params):
992
- return self.exact_binary_function(jnp.greater_equal)
993
-
994
- def greater(self, id, init_params):
995
- return self.exact_binary_function(jnp.greater)
996
-
997
- def less_equal(self, id, init_params):
998
- return self.exact_binary_function(jnp.less_equal)
999
-
1000
- def less(self, id, init_params):
1001
- return self.exact_binary_function(jnp.less)
1002
-
1003
- def equal(self, id, init_params):
1004
- return self.exact_binary_function(jnp.equal)
1005
-
1006
- def not_equal(self, id, init_params):
1007
- return self.exact_binary_function(jnp.not_equal)
1008
-
1009
- # ===========================================================================
1010
- # special functions
1011
- # ===========================================================================
1012
-
1013
- @staticmethod
1014
- def exact_binary_log(x, y, params):
1015
- return jnp.log(x) / jnp.log(y), params
1016
-
1017
- def sgn(self, id, init_params):
1018
- return self.exact_unary_function(jnp.sign)
1019
-
1020
- def floor(self, id, init_params):
1021
- return self.exact_unary_function(jnp.floor)
1022
-
1023
- def round(self, id, init_params):
1024
- return self.exact_unary_function(jnp.round)
1025
-
1026
- def ceil(self, id, init_params):
1027
- return self.exact_unary_function(jnp.ceil)
1028
-
1029
- def div(self, id, init_params):
1030
- return self.exact_binary_function(jnp.floor_divide)
1031
-
1032
- def mod(self, id, init_params):
1033
- return self.exact_binary_function(jnp.mod)
1034
-
1035
- def sqrt(self, id, init_params):
1036
- return self.exact_unary_function(jnp.sqrt)
1037
-
1038
- # ===========================================================================
1039
- # indexing
1040
- # ===========================================================================
1041
-
1042
- def argmax(self, id, init_params):
1043
- return self.exact_aggregation(jnp.argmax)
1044
-
1045
- def argmin(self, id, init_params):
1046
- return self.exact_aggregation(jnp.argmin)
1047
-
1048
- # ===========================================================================
1049
- # control flow
1050
- # ===========================================================================
1051
-
1052
- @staticmethod
1053
- def _jax_wrapped_calc_if_then_else_exact(c, a, b, params):
1054
- return jnp.where(c > 0.5, a, b), params
1055
-
1056
- def control_if(self, id, init_params):
1057
- return self._jax_wrapped_calc_if_then_else_exact
1058
-
1059
- def control_switch(self, id, init_params):
1060
- def _jax_wrapped_calc_switch_exact(pred, cases, params):
1061
- pred = jnp.asarray(pred[jnp.newaxis, ...], dtype=self.INT)
1062
- sample = jnp.take_along_axis(cases, pred, axis=0)
1063
- assert sample.shape[0] == 1
1064
- return sample[0, ...], params
1065
- return _jax_wrapped_calc_switch_exact
1066
-
1067
- # ===========================================================================
1068
- # random variables
1069
- # ===========================================================================
1070
-
1071
- @staticmethod
1072
- def _jax_wrapped_calc_discrete_exact(key, prob, params):
1073
- sample = random.categorical(key=key, logits=jnp.log(prob), axis=-1)
1074
- return sample, params
1075
-
1076
- def discrete(self, id, init_params):
1077
- return self._jax_wrapped_calc_discrete_exact
1078
-
1079
- @staticmethod
1080
- def _jax_wrapped_calc_bernoulli_exact(key, prob, params):
1081
- return random.bernoulli(key, prob), params
1082
-
1083
- def bernoulli(self, id, init_params):
1084
- return self._jax_wrapped_calc_bernoulli_exact
1085
-
1086
- def poisson(self, id, init_params):
1087
- def _jax_wrapped_calc_poisson_exact(key, rate, params):
1088
- sample = random.poisson(key=key, lam=rate, dtype=self.INT)
1089
- return sample, params
1090
- return _jax_wrapped_calc_poisson_exact
1091
-
1092
- def geometric(self, id, init_params):
1093
- def _jax_wrapped_calc_geometric_exact(key, prob, params):
1094
- sample = random.geometric(key=key, p=prob, dtype=self.INT)
1095
- return sample, params
1096
- return _jax_wrapped_calc_geometric_exact
1097
-
1098
- def binomial(self, id, init_params):
1099
- def _jax_wrapped_calc_binomial_exact(key, trials, prob, params):
1349
+ # if trials and prob are non-fluent, always use the exact operation
1350
+ if (not self.traced.cached_is_fluent(arg_trials) and
1351
+ not self.traced.cached_is_fluent(arg_prob)):
1352
+ return super()._jax_binomial(expr, aux)
1353
+
1354
+ aux['overriden'][expr.id] = __class__.__name__
1355
+
1356
+ jax_trials = self._jax(arg_trials, aux)
1357
+ jax_prob = self._jax(arg_prob, aux)
1358
+
1359
+ def _jax_wrapped_distribution_binomial_determinized(fls, nfls, params, key):
1360
+ trials, key, err2, params = jax_trials(fls, nfls, params, key)
1361
+ prob, key, err1, params = jax_prob(fls, nfls, params, key)
1100
1362
  trials = jnp.asarray(trials, dtype=self.REAL)
1101
1363
  prob = jnp.asarray(prob, dtype=self.REAL)
1102
- sample = random.binomial(key=key, n=trials, p=prob, dtype=self.REAL)
1103
- sample = jnp.asarray(sample, dtype=self.INT)
1104
- return sample, params
1105
- return _jax_wrapped_calc_binomial_exact
1106
-
1107
- # note: for some reason tfp defines it as number of successes before trials failures
1108
- # I will define it as the number of failures before trials successes
1109
- def negative_binomial(self, id, init_params):
1110
- def _jax_wrapped_calc_negative_binomial_exact(key, trials, prob, params):
1364
+ sample = trials * prob
1365
+ out_of_bounds = jnp.logical_not(jnp.all(
1366
+ jnp.logical_and(jnp.logical_and(prob >= 0, prob <= 1), trials >= 0)))
1367
+ err = err1 | err2 | (out_of_bounds * ERR)
1368
+ return sample, key, err, params
1369
+ return _jax_wrapped_distribution_binomial_determinized
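A quick numerical check (not package code) that the determinized sample trials * prob matches the exact Binomial sampler in expectation:

import jax.numpy as jnp
from jax import random

n, p = 20, 0.3
exact = random.binomial(random.PRNGKey(0), n=n, p=p, shape=(100000,))
print(exact.mean())  # approximately 6.0
print(n * p)         # 6.0, the determinized sample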
1370
+
1371
+
1372
+ # ===============================================================================
1373
+ # distribution relaxations - Poisson and NegativeBinomial
1374
+ # ===============================================================================
1375
+
1376
+ class ExponentialPoisson(JaxRDDLCompilerWithGrad):
1377
+
1378
+ def __init__(self, *args,
1379
+ poisson_nbins: int=100,
1380
+ poisson_comparison_weight: float=10.,
1381
+ poisson_min_cdf: float=0.999, **kwargs) -> None:
1382
+ super(ExponentialPoisson, self).__init__(*args, **kwargs)
1383
+ self.poisson_nbins = poisson_nbins
1384
+ self.poisson_comparison_weight = float(poisson_comparison_weight)
1385
+ self.poisson_min_cdf = float(poisson_min_cdf)
1386
+
1387
+ def get_kwargs(self) -> Dict[str, Any]:
1388
+ kwargs = super().get_kwargs()
1389
+ kwargs['poisson_nbins'] = self.poisson_nbins
1390
+ kwargs['poisson_comparison_weight'] = self.poisson_comparison_weight
1391
+ kwargs['poisson_min_cdf'] = self.poisson_min_cdf
1392
+ return kwargs
1393
+
1394
+ def exponential_approx_to_poisson(self, key: random.PRNGKey,
1395
+ rate: jnp.ndarray, w: float) -> jnp.ndarray:
1396
+ exp = random.exponential(
1397
+ key=key, shape=(self.poisson_nbins,) + jnp.shape(rate), dtype=rate.dtype)
1398
+ delta_t = exp / rate[jnp.newaxis, ...]
1399
+ times = jnp.cumsum(delta_t, axis=0)
1400
+ indicator = stable_sigmoid(w * (1. - times))
1401
+ return jnp.sum(indicator, axis=0)
1402
+
1403
+ def branched_approx_to_poisson(self, key: random.PRNGKey,
1404
+ rate: jnp.ndarray, w: float, min_cdf: float) -> jnp.ndarray:
1405
+ cuml_prob = scipy.stats.poisson.cdf(self.poisson_nbins, rate)
1406
+ z = random.normal(key=key, shape=jnp.shape(rate), dtype=rate.dtype)
1407
+ return jnp.where(
1408
+ jax.lax.stop_gradient(cuml_prob >= min_cdf),
1409
+ self.exponential_approx_to_poisson(key, rate, w),
1410
+ rate + jnp.sqrt(rate) * z
1411
+ )
1412
+
1413
+ def _jax_poisson(self, expr, aux):
1414
+ ERR = JaxRDDLCompilerWithGrad.ERROR_CODES['INVALID_PARAM_POISSON']
1415
+ JaxRDDLCompilerWithGrad._check_num_args(expr, 1)
1416
+ arg_rate, = expr.args
1417
+
1418
+ # if rate is non-fluent, always use the exact operation
1419
+ if not self.traced.cached_is_fluent(arg_rate):
1420
+ return super()._jax_poisson(expr, aux)
1421
+
1422
+ id_ = expr.id
1423
+ aux['params'][id_] = (self.poisson_comparison_weight, self.poisson_min_cdf)
1424
+ aux['overriden'][id_] = __class__.__name__
1425
+
1426
+ jax_rate = self._jax(arg_rate, aux)
1427
+
1428
+ # use the exponential/Poisson process trick for small rate
1429
+ # use the normal approximation for large rate
1430
+ def _jax_wrapped_distribution_poisson_exponential(fls, nfls, params, key):
1431
+ rate, key, err, params = jax_rate(fls, nfls, params, key)
1432
+ key, subkey = random.split(key)
1433
+ sample = self.branched_approx_to_poisson(subkey, rate, *params[id_])
1434
+ out_of_bounds = jnp.logical_not(jnp.all(rate >= 0))
1435
+ err = err | (out_of_bounds * ERR)
1436
+ return sample, key, err, params
1437
+ return _jax_wrapped_distribution_poisson_exponential
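A standalone sketch (not package code) of the Poisson-process trick used above: draw exponential inter-arrival gaps and count how many cumulative arrival times fall before t = 1; softening the count with a sigmoid makes it differentiable in the rate.

import jax
import jax.numpy as jnp
from jax import random

def relaxed_poisson(key, rate, nbins=100, w=10.0):
    gaps = random.exponential(key, shape=(nbins,))     # Exp(1) inter-arrival gaps
    times = jnp.cumsum(gaps / rate)                    # arrival times of a rate-`rate` process
    return jnp.sum(jax.nn.sigmoid(w * (1.0 - times)))  # soft count of arrivals before t = 1

keys = random.split(random.PRNGKey(0), 2000)
samples = jax.vmap(lambda k: relaxed_poisson(k, 3.0))(keys)
print(samples.mean())                                        # roughly 3
print(jax.grad(lambda r: relaxed_poisson(keys[0], r))(3.0))  # nonzero reparameterized gradient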
1438
+
1439
+ def _jax_negative_binomial(self, expr, aux):
1440
+ ERR = JaxRDDLCompilerWithGrad.ERROR_CODES['INVALID_PARAM_NEGATIVE_BINOMIAL']
1441
+ JaxRDDLCompilerWithGrad._check_num_args(expr, 2)
1442
+ arg_trials, arg_prob = expr.args
1443
+
1444
+ # if prob and trials are non-fluent, always use the exact operation
1445
+ if (not self.traced.cached_is_fluent(arg_trials) and
1446
+ not self.traced.cached_is_fluent(arg_prob)):
1447
+ return super()._jax_negative_binomial(expr, aux)
1448
+
1449
+ id_ = expr.id
1450
+ aux['params'][id_] = (self.poisson_comparison_weight, self.poisson_min_cdf)
1451
+ aux['overriden'][id_] = __class__.__name__
1452
+
1453
+ jax_trials = self._jax(arg_trials, aux)
1454
+ jax_prob = self._jax(arg_prob, aux)
1455
+
1456
+ # https://en.wikipedia.org/wiki/Negative_binomial_distribution#Gamma%E2%80%93Poisson_mixture
1457
+ def _jax_wrapped_distribution_negative_binomial_exponential(fls, nfls, params, key):
1458
+ trials, key, err2, params = jax_trials(fls, nfls, params, key)
1459
+ prob, key, err1, params = jax_prob(fls, nfls, params, key)
1460
+ key, subkey = random.split(key)
1111
1461
  trials = jnp.asarray(trials, dtype=self.REAL)
1112
1462
  prob = jnp.asarray(prob, dtype=self.REAL)
1113
- dist = tfp.distributions.NegativeBinomial(total_count=trials, probs=1.0 - prob)
1114
- sample = jnp.asarray(dist.sample(seed=key), dtype=self.INT)
1115
- return sample, params
1116
- return _jax_wrapped_calc_negative_binomial_exact
1463
+ gamma = random.gamma(key=subkey, a=trials, dtype=self.REAL)
1464
+ rate = ((1.0 - prob) / prob) * gamma
1465
+ sample = self.branched_approx_to_poisson(subkey, rate, *params[id_])
1466
+ out_of_bounds = jnp.logical_not(jnp.all(
1467
+ jnp.logical_and(jnp.logical_and(prob >= 0, prob <= 1), trials > 0)))
1468
+ err = err1 | err2 | (out_of_bounds * ERR)
1469
+ return sample, key, err, params
1470
+ return _jax_wrapped_distribution_negative_binomial_exponential
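The Gamma-Poisson mixture used above can be checked numerically: NegativeBinomial(r, p), read as the number of failures before r successes, has the same law as Poisson(G * (1 - p) / p) with G ~ Gamma(r). A small sketch (not package code; the inner Poisson is kept exact just for the check):

import jax.numpy as jnp
from jax import random

def negbin_via_gamma_poisson(key, r, p, n=100000):
    k1, k2 = random.split(key)
    gamma = random.gamma(k1, a=r, shape=(n,))
    rate = gamma * (1.0 - p) / p
    return random.poisson(k2, rate)

samples = negbin_via_gamma_poisson(random.PRNGKey(7), r=5.0, p=0.4)
print(samples.mean())  # approximately r * (1 - p) / p    = 7.5
print(samples.var())   # approximately r * (1 - p) / p**2 = 18.75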
1471
+
1472
+
1473
+ class GumbelSoftmaxPoisson(JaxRDDLCompilerWithGrad):
1474
+
1475
+ def __init__(self, *args,
1476
+ poisson_nbins: int=100,
1477
+ poisson_softmax_weight: float=10.,
1478
+ poisson_min_cdf: float=0.999,
1479
+ poisson_eps: float=1e-14, **kwargs) -> None:
1480
+ super(GumbelSoftmaxPoisson, self).__init__(*args, **kwargs)
1481
+ self.poisson_nbins = poisson_nbins
1482
+ self.poisson_softmax_weight = float(poisson_softmax_weight)
1483
+ self.poisson_min_cdf = float(poisson_min_cdf)
1484
+ self.poisson_eps = float(poisson_eps)
1485
+
1486
+ def get_kwargs(self) -> Dict[str, Any]:
1487
+ kwargs = super().get_kwargs()
1488
+ kwargs['poisson_nbins'] = self.poisson_nbins
1489
+ kwargs['poisson_softmax_weight'] = self.poisson_softmax_weight
1490
+ kwargs['poisson_min_cdf'] = self.poisson_min_cdf
1491
+ kwargs['poisson_eps'] = self.poisson_eps
1492
+ return kwargs
1493
+
1494
+ def gumbel_softmax_poisson(self, key: random.PRNGKey,
1495
+ rate: jnp.ndarray, w: float, eps: float) -> jnp.ndarray:
1496
+ ks = jnp.arange(self.poisson_nbins)[(jnp.newaxis,) * jnp.ndim(rate) + (...,)]
1497
+ rate = rate[..., jnp.newaxis]
1498
+ log_prob = ks * jnp.log(rate + eps) - rate - scipy.special.gammaln(ks + 1)
1499
+ g = random.gumbel(key=key, shape=jnp.shape(log_prob), dtype=rate.dtype)
1500
+ return SoftmaxArgmax.soft_argmax(g + log_prob, w=w, axes=-1)
1501
+
1502
+ def branched_approx_to_poisson(self, key: random.PRNGKey,
1503
+ rate: jnp.ndarray,
1504
+ w: float, min_cdf: float, eps: float) -> jnp.ndarray:
1505
+ cuml_prob = scipy.stats.poisson.cdf(self.poisson_nbins, rate)
1506
+ z = random.normal(key=key, shape=jnp.shape(rate), dtype=rate.dtype)
1507
+ return jnp.where(
1508
+ jax.lax.stop_gradient(cuml_prob >= min_cdf),
1509
+ self.gumbel_softmax_poisson(key, rate, w, eps),
1510
+ rate + jnp.sqrt(rate) * z
1511
+ )
1512
+
1513
+ def _jax_poisson(self, expr, aux):
1514
+ ERR = JaxRDDLCompilerWithGrad.ERROR_CODES['INVALID_PARAM_POISSON']
1515
+ JaxRDDLCompilerWithGrad._check_num_args(expr, 1)
1516
+ arg_rate, = expr.args
1517
+
1518
+ # if rate is non-fluent, always use the exact operation
1519
+ if not self.traced.cached_is_fluent(arg_rate):
1520
+ return super()._jax_poisson(expr, aux)
1521
+
1522
+ id_ = expr.id
1523
+ aux['params'][id_] = (self.poisson_softmax_weight, self.poisson_min_cdf, self.poisson_eps)
1524
+ aux['overriden'][id_] = __class__.__name__
1117
1525
 
1118
-
1119
- class FuzzyLogic(Logic):
1120
- '''A class representing fuzzy logic in JAX.'''
1121
-
1122
- def __init__(self, tnorm: TNorm=ProductTNorm(),
1123
- complement: Complement=StandardComplement(),
1124
- comparison: Comparison=SigmoidComparison(),
1125
- sampling: RandomSampling=SoftRandomSampling(),
1126
- rounding: Rounding=SoftRounding(),
1127
- control: ControlFlow=SoftControlFlow(),
1128
- eps: float=1e-15,
1129
- use64bit: bool=False) -> None:
1130
- '''Creates a new fuzzy logic in Jax.
1131
-
1132
- :param tnorm: fuzzy operator for logical AND
1133
- :param complement: fuzzy operator for logical NOT
1134
- :param comparison: fuzzy operator for comparisons (>, >=, <, ==, ~=, ...)
1135
- :param sampling: random sampling of non-reparameterizable distributions
1136
- :param rounding: rounding floating values to integers
1137
- :param control: if and switch control structures
1138
- :param eps: small positive float to mitigate underflow
1139
- :param use64bit: whether to perform arithmetic in 64 bit
1140
- '''
1141
- super().__init__(use64bit=use64bit)
1142
- self.tnorm = tnorm
1143
- self.complement = complement
1144
- self.comparison = comparison
1145
- self.sampling = sampling
1146
- self.rounding = rounding
1147
- self.control = control
1148
- self.eps = eps
1149
-
1150
- def __str__(self) -> str:
1151
- return (f'model relaxation:\n'
1152
- f' tnorm ={str(self.tnorm)}\n'
1153
- f' complement ={str(self.complement)}\n'
1154
- f' comparison ={str(self.comparison)}\n'
1155
- f' sampling ={str(self.sampling)}\n'
1156
- f' rounding ={str(self.rounding)}\n'
1157
- f' control ={str(self.control)}\n'
1158
- f' underflow_tol={self.eps}\n'
1159
- f' use_64_bit ={self.use64bit}\n')
1160
-
1161
- def summarize_hyperparameters(self) -> str:
1162
- return self.__str__()
1163
-
1164
- # ===========================================================================
1165
- # logical operators
1166
- # ===========================================================================
1167
-
1168
- def logical_and(self, id, init_params):
1169
- return self.tnorm.norm(id, init_params)
1170
-
1171
- def logical_not(self, id, init_params):
1172
- return self.complement(id, init_params)
1173
-
1174
- def logical_or(self, id, init_params):
1175
- _not1 = self.complement(f'{id}_~1', init_params)
1176
- _not2 = self.complement(f'{id}_~2', init_params)
1177
- _and = self.tnorm.norm(f'{id}_^', init_params)
1178
- _not = self.complement(f'{id}_~', init_params)
1179
-
1180
- def _jax_wrapped_calc_or_approx(x, y, params):
1181
- not_x, params = _not1(x, params)
1182
- not_y, params = _not2(y, params)
1183
- not_x_and_not_y, params = _and(not_x, not_y, params)
1184
- return _not(not_x_and_not_y, params)
1185
- return _jax_wrapped_calc_or_approx
1186
-
1187
- def xor(self, id, init_params):
1188
- _not = self.complement(f'{id}_~', init_params)
1189
- _and1 = self.tnorm.norm(f'{id}_^1', init_params)
1190
- _and2 = self.tnorm.norm(f'{id}_^2', init_params)
1191
- _or = self.logical_or(f'{id}_|', init_params)
1192
-
1193
- def _jax_wrapped_calc_xor_approx(x, y, params):
1194
- x_and_y, params = _and1(x, y, params)
1195
- not_x_and_y, params = _not(x_and_y, params)
1196
- x_or_y, params = _or(x, y, params)
1197
- return _and2(x_or_y, not_x_and_y, params)
1198
- return _jax_wrapped_calc_xor_approx
1199
-
1200
- def implies(self, id, init_params):
1201
- _not = self.complement(f'{id}_~', init_params)
1202
- _or = self.logical_or(f'{id}_|', init_params)
1203
-
1204
- def _jax_wrapped_calc_implies_approx(x, y, params):
1205
- not_x, params = _not(x, params)
1206
- return _or(not_x, y, params)
1207
- return _jax_wrapped_calc_implies_approx
1208
-
1209
- def equiv(self, id, init_params):
1210
- _implies1 = self.implies(f'{id}_=>1', init_params)
1211
- _implies2 = self.implies(f'{id}_=>2', init_params)
1212
- _and = self.tnorm.norm(f'{id}_^', init_params)
1213
-
1214
- def _jax_wrapped_calc_equiv_approx(x, y, params):
1215
- x_implies_y, params = _implies1(x, y, params)
1216
- y_implies_x, params = _implies2(y, x, params)
1217
- return _and(x_implies_y, y_implies_x, params)
1218
- return _jax_wrapped_calc_equiv_approx
1219
-
1220
- def forall(self, id, init_params):
1221
- return self.tnorm.norms(id, init_params)
1222
-
1223
- def exists(self, id, init_params):
1224
- _not1 = self.complement(f'{id}_~1', init_params)
1225
- _not2 = self.complement(f'{id}_~2', init_params)
1226
- _forall = self.forall(f'{id}_forall', init_params)
1227
-
1228
- def _jax_wrapped_calc_exists_approx(x, axis, params):
1229
- not_x, params = _not1(x, params)
1230
- forall_not_x, params = _forall(not_x, axis, params)
1231
- return _not2(forall_not_x, params)
1232
- return _jax_wrapped_calc_exists_approx
1233
-
1234
- # ===========================================================================
1235
- # comparison operators
1236
- # ===========================================================================
1237
-
1238
- def greater_equal(self, id, init_params):
1239
- return self.comparison.greater_equal(id, init_params)
1240
-
1241
- def greater(self, id, init_params):
1242
- return self.comparison.greater(id, init_params)
1243
-
1244
- def less_equal(self, id, init_params):
1245
- _greater_eq = self.greater_equal(id, init_params)
1246
- def _jax_wrapped_calc_leq_approx(x, y, params):
1247
- return _greater_eq(-x, -y, params)
1248
- return _jax_wrapped_calc_leq_approx
1249
-
1250
- def less(self, id, init_params):
1251
- _greater = self.greater(id, init_params)
1252
- def _jax_wrapped_calc_less_approx(x, y, params):
1253
- return _greater(-x, -y, params)
1254
- return _jax_wrapped_calc_less_approx
1255
-
1256
- def equal(self, id, init_params):
1257
- return self.comparison.equal(id, init_params)
1258
-
1259
- def not_equal(self, id, init_params):
1260
- _not = self.complement(f'{id}_~', init_params)
1261
- _equal = self.comparison.equal(f'{id}_==', init_params)
1262
- def _jax_wrapped_calc_neq_approx(x, y, params):
1263
- equal, params = _equal(x, y, params)
1264
- return _not(equal, params)
1265
- return _jax_wrapped_calc_neq_approx
1266
-
1267
- # ===========================================================================
1268
- # special functions
1269
- # ===========================================================================
1270
-
1271
- def sgn(self, id, init_params):
1272
- return self.comparison.sgn(id, init_params)
1273
-
1274
- def floor(self, id, init_params):
1275
- return self.rounding.floor(id, init_params)
1526
+ jax_rate = self._jax(arg_rate, aux)
1276
1527
 
1277
- def round(self, id, init_params):
1278
- return self.rounding.round(id, init_params)
1279
-
1280
- def ceil(self, id, init_params):
1281
- _floor = self.rounding.floor(id, init_params)
1282
- def _jax_wrapped_calc_ceil_approx(x, params):
1283
- neg_floor, params = _floor(-x, params)
1284
- return -neg_floor, params
1285
- return _jax_wrapped_calc_ceil_approx
1286
-
1287
- def div(self, id, init_params):
1288
- _floor = self.rounding.floor(id, init_params)
1289
- def _jax_wrapped_calc_div_approx(x, y, params):
1290
- return _floor(x / y, params)
1291
- return _jax_wrapped_calc_div_approx
1292
-
1293
- def mod(self, id, init_params):
1294
- _div = self.div(id, init_params)
1295
- def _jax_wrapped_calc_mod_approx(x, y, params):
1296
- div, params = _div(x, y, params)
1297
- return x - y * div, params
1298
- return _jax_wrapped_calc_mod_approx
1299
-
1300
- def sqrt(self, id, init_params):
1301
- def _jax_wrapped_calc_sqrt_approx(x, params):
1302
- return jnp.sqrt(x + self.eps), params
1303
- return _jax_wrapped_calc_sqrt_approx
1304
-
1305
- # ===========================================================================
1306
- # indexing
1307
- # ===========================================================================
1308
-
1309
- def argmax(self, id, init_params):
1310
- return self.comparison.argmax(id, init_params)
1311
-
1312
- def argmin(self, id, init_params):
1313
- _argmax = self.argmax(id, init_params)
1314
- def _jax_wrapped_calc_argmin_approx(x, axis, param):
1315
- return _argmax(-x, axis, param)
1316
- return _jax_wrapped_calc_argmin_approx
1317
-
1318
- # ===========================================================================
1319
- # control flow
1320
- # ===========================================================================
1321
-
1322
- def control_if(self, id, init_params):
1323
- return self.control.if_then_else(id, init_params)
1324
-
1325
- def control_switch(self, id, init_params):
1326
- return self.control.switch(id, init_params)
1327
-
1328
- # ===========================================================================
1329
- # random variables
1330
- # ===========================================================================
1331
-
1332
- def discrete(self, id, init_params):
1333
- return self.sampling.discrete(id, init_params, self)
1334
-
1335
- def bernoulli(self, id, init_params):
1336
- return self.sampling.bernoulli(id, init_params, self)
1337
-
1338
- def poisson(self, id, init_params):
1339
- return self.sampling.poisson(id, init_params, self)
1340
-
1341
- def geometric(self, id, init_params):
1342
- return self.sampling.geometric(id, init_params, self)
1343
-
1344
- def binomial(self, id, init_params):
1345
- return self.sampling.binomial(id, init_params, self)
1346
-
1347
- def negative_binomial(self, id, init_params):
1348
- return self.sampling.negative_binomial(id, init_params, self)
1528
+ # use the gumbel-softmax and truncation trick for small rate
1529
+ # use the normal approximation for large rate
1530
+ def _jax_wrapped_distribution_poisson_gumbel_softmax(fls, nfls, params, key):
1531
+ rate, key, err, params = jax_rate(fls, nfls, params, key)
1532
+ key, subkey = random.split(key)
1533
+ sample = self.branched_approx_to_poisson(subkey, rate, *params[id_])
1534
+ out_of_bounds = jnp.logical_not(jnp.all(rate >= 0))
1535
+ err = err | (out_of_bounds * ERR)
1536
+ return sample, key, err, params
1537
+ return _jax_wrapped_distribution_poisson_gumbel_softmax
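The comment above describes the branching rule: for small rates the sampler relaxes a truncated Poisson support with Gumbel-softmax, and for large rates it switches to the normal approximation N(rate, rate). A rough, self-contained illustration of that idea follows; it is not the package's branched_approx_to_poisson (whose truncation is driven by the poisson_softmax_weight, poisson_min_cdf and poisson_eps parameters stored in aux['params']), and the weight, max_support and switch knobs here are invented for the sketch:

import jax
import jax.numpy as jnp
from jax import random
from jax.scipy.special import gammaln

def poisson_relaxed(key, rate, weight=10.0, max_support=32, switch=10.0):
    key_g, key_n = random.split(key)
    # small rates: Gumbel-softmax over a truncated support weighted by the Poisson pmf
    ks = jnp.arange(max_support + 1, dtype=rate.dtype)
    log_pmf = ks * jnp.log(rate[..., None]) - rate[..., None] - gammaln(ks + 1.0)
    gumbel = random.gumbel(key_g, shape=log_pmf.shape, dtype=rate.dtype)
    soft_onehot = jax.nn.softmax(weight * (log_pmf + gumbel), axis=-1)
    small = jnp.sum(soft_onehot * ks, axis=-1)
    # large rates: normal approximation N(rate, rate)
    large = rate + jnp.sqrt(rate) * random.normal(key_n, shape=rate.shape, dtype=rate.dtype)
    return jnp.where(rate < switch, small, large)

print(poisson_relaxed(random.PRNGKey(0), jnp.asarray([2.0, 50.0])))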
1538
+
1539
+ def _jax_negative_binomial(self, expr, aux):
1540
+ ERR = JaxRDDLCompilerWithGrad.ERROR_CODES['INVALID_PARAM_NEGATIVE_BINOMIAL']
1541
+ JaxRDDLCompilerWithGrad._check_num_args(expr, 2)
1542
+ arg_trials, arg_prob = expr.args
1543
+
1544
+ # if prob and trials are non-fluent, always use the exact operation
1545
+ if (not self.traced.cached_is_fluent(arg_trials) and
1546
+ not self.traced.cached_is_fluent(arg_prob)):
1547
+ return super()._jax_negative_binomial(expr, aux)
1548
+
1549
+ id_ = expr.id
1550
+ aux['params'][id_] = (self.poisson_softmax_weight, self.poisson_min_cdf, self.poisson_eps)
1551
+ aux['overriden'][id_] = __class__.__name__
1552
+
1553
+ jax_trials = self._jax(arg_trials, aux)
1554
+ jax_prob = self._jax(arg_prob, aux)
1349
1555
 
1556
+ # https://en.wikipedia.org/wiki/Negative_binomial_distribution#Gamma%E2%80%93Poisson_mixture
1557
+ def _jax_wrapped_distribution_negative_binomial_gumbel_softmax(fls, nfls, params, key):
1558
+ trials, key, err2, params = jax_trials(fls, nfls, params, key)
1559
+ prob, key, err1, params = jax_prob(fls, nfls, params, key)
1560
+ key, subkey = random.split(key)
1561
+ trials = jnp.asarray(trials, dtype=self.REAL)
1562
+ prob = jnp.asarray(prob, dtype=self.REAL)
1563
+ gamma = random.gamma(key=subkey, a=trials, dtype=self.REAL)
1564
+ rate = ((1.0 - prob) / prob) * gamma
1565
+ sample = self.branched_approx_to_poisson(subkey, rate, *params[id_])
1566
+ out_of_bounds = jnp.logical_not(jnp.all(
1567
+ jnp.logical_and(jnp.logical_and(prob >= 0, prob <= 1), trials > 0)))
1568
+ err = err1 | err2 | (out_of_bounds * ERR)
1569
+ return sample, key, err, params
1570
+ return _jax_wrapped_distribution_negative_binomial_gumbel_softmax
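The linked Gamma-Poisson mixture is what makes this reduction work: if lambda = Gamma(trials) * (1 - prob) / prob and X | lambda ~ Poisson(lambda), then X is negative binomial in (trials, prob), so the already-relaxed Poisson path can be reused. A quick sanity check of the mixture with exact (non-relaxed) samplers, comparing the sample mean against the closed-form mean trials * (1 - prob) / prob:

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(1)
trials, prob = 5.0, 0.4
key_gamma, key_poisson = random.split(key)

# lambda ~ Gamma(trials) scaled by (1 - prob) / prob, then X | lambda ~ Poisson(lambda)
lam = ((1.0 - prob) / prob) * random.gamma(key_gamma, a=trials, shape=(100_000,))
x = random.poisson(key_poisson, lam)

print(x.mean(), trials * (1.0 - prob) / prob)   # both approximately 7.5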
1350
1571
 
1351
- # ===========================================================================
1352
- # UNIT TESTS
1353
- #
1354
- # ===========================================================================
1355
-
1356
- logic = FuzzyLogic(comparison=SigmoidComparison(10000.0),
1357
- rounding=SoftRounding(10000.0),
1358
- control=SoftControlFlow(10000.0))
1359
-
1360
-
1361
- def _test_logical():
1362
- print('testing logical')
1363
- init_params = {}
1364
- _and = logic.logical_and(0, init_params)
1365
- _not = logic.logical_not(1, init_params)
1366
- _gre = logic.greater(2, init_params)
1367
- _or = logic.logical_or(3, init_params)
1368
- _if = logic.control_if(4, init_params)
1369
- print(init_params)
1370
-
1371
- # https://towardsdatascience.com/emulating-logical-gates-with-a-neural-network-75c229ec4cc9
1372
- def test_logic(x1, x2, w):
1373
- q1, w = _gre(x1, 0, w)
1374
- q2, w = _gre(x2, 0, w)
1375
- q3, w = _and(q1, q2, w)
1376
- q4, w = _not(q1, w)
1377
- q5, w = _not(q2, w)
1378
- q6, w = _and(q4, q5, w)
1379
- cond, w = _or(q3, q6, w)
1380
- pred, w = _if(cond, +1, -1, w)
1381
- return pred
1382
-
1383
- x1 = jnp.asarray([1, 1, -1, -1, 0.1, 15, -0.5], dtype=float)
1384
- x2 = jnp.asarray([1, -1, 1, -1, 10, -30, 6], dtype=float)
1385
- print(test_logic(x1, x2, init_params))
1386
-
1387
-
1388
- def _test_indexing():
1389
- print('testing indexing')
1390
- init_params = {}
1391
- _argmax = logic.argmax(0, init_params)
1392
- _argmin = logic.argmin(1, init_params)
1393
- print(init_params)
1394
-
1395
- def argmaxmin(x, w):
1396
- amax, w = _argmax(x, 0, w)
1397
- amin, w = _argmin(x, 0, w)
1398
- return amax, amin
1399
-
1400
- values = jnp.asarray([2., 3., 5., 4.9, 4., 1., -1., -2.])
1401
- amax, amin = argmaxmin(values, init_params)
1402
- print(amax)
1403
- print(amin)
1404
-
1405
-
1406
- def _test_control():
1407
- print('testing control')
1408
- init_params = {}
1409
- _switch = logic.control_switch(0, init_params)
1410
- print(init_params)
1411
-
1412
- pred = jnp.asarray(jnp.linspace(0, 2, 10))
1413
- case1 = jnp.asarray([-10.] * 10)
1414
- case2 = jnp.asarray([1.5] * 10)
1415
- case3 = jnp.asarray([10.] * 10)
1416
- cases = jnp.asarray([case1, case2, case3])
1417
- switch, _ = _switch(pred, cases, init_params)
1418
- print(switch)
1419
-
1420
-
1421
- def _test_random():
1422
- print('testing random')
1423
- key = random.PRNGKey(42)
1424
- init_params = {}
1425
- _bernoulli = logic.bernoulli(0, init_params)
1426
- _discrete = logic.discrete(1, init_params)
1427
- _geometric = logic.geometric(2, init_params)
1428
- print(init_params)
1429
-
1430
- def bern(n, w):
1431
- prob = jnp.asarray([0.3] * n)
1432
- sample, _ = _bernoulli(key, prob, w)
1433
- return sample
1434
-
1435
- samples = bern(50000, init_params)
1436
- print(jnp.mean(samples))
1437
-
1438
- def disc(n, w):
1439
- prob = jnp.asarray([0.1, 0.4, 0.5])
1440
- prob = jnp.tile(prob, (n, 1))
1441
- sample, _ = _discrete(key, prob, w)
1442
- return sample
1443
-
1444
- samples = disc(50000, init_params)
1445
- samples = jnp.round(samples)
1446
- print([jnp.mean(samples == i) for i in range(3)])
1447
-
1448
- def geom(n, w):
1449
- prob = jnp.asarray([0.3] * n)
1450
- sample, _ = _geometric(key, prob, w)
1451
- return sample
1452
-
1453
- samples = geom(50000, init_params)
1454
- print(jnp.mean(samples))
1455
-
1456
1572
 
1457
- def _test_rounding():
1458
- print('testing rounding')
1459
- init_params = {}
1460
- _floor = logic.floor(0, init_params)
1461
- _ceil = logic.ceil(1, init_params)
1462
- _round = logic.round(2, init_params)
1463
- _mod = logic.mod(3, init_params)
1464
- print(init_params)
1465
-
1466
- x = jnp.asarray([2.1, 0.6, 1.99, -2.01, -3.2, -0.1, -1.01, 23.01, -101.99, 200.01])
1467
- print(_floor(x, init_params)[0])
1468
- print(_ceil(x, init_params)[0])
1469
- print(_round(x, init_params)[0])
1470
- print(_mod(x, 2.0, init_params)[0])
1471
-
1472
-
1473
- if __name__ == '__main__':
1474
- _test_logical()
1475
- _test_indexing()
1476
- _test_control()
1477
- _test_random()
1478
- _test_rounding()
1479
-
1573
+ class DeterminizedPoisson(JaxRDDLCompilerWithGrad):
1574
+
1575
+ def __init__(self, *args, **kwargs) -> None:
1576
+ super(DeterminizedPoisson, self).__init__(*args, **kwargs)
1577
+
1578
+ def get_kwargs(self) -> Dict[str, Any]:
1579
+ return super().get_kwargs()
1580
+
1581
+ def _jax_poisson(self, expr, aux):
1582
+ ERR = JaxRDDLCompilerWithGrad.ERROR_CODES['INVALID_PARAM_POISSON']
1583
+ JaxRDDLCompilerWithGrad._check_num_args(expr, 1)
1584
+ arg_rate, = expr.args
1585
+
1586
+ # if rate is non-fluent, always use the exact operation
1587
+ if not self.traced.cached_is_fluent(arg_rate):
1588
+ return super()._jax_poisson(expr, aux)
1589
+
1590
+ aux['overriden'][expr.id] = __class__.__name__
1591
+
1592
+ jax_rate = self._jax(arg_rate, aux)
1593
+
1594
+ def _jax_wrapped_distribution_poisson_determinized(fls, nfls, params, key):
1595
+ rate, key, err, params = jax_rate(fls, nfls, params, key)
1596
+ sample = rate
1597
+ out_of_bounds = jnp.logical_not(jnp.all(rate >= 0))
1598
+ err = err | (out_of_bounds * ERR)
1599
+ return sample, key, err, params
1600
+ return _jax_wrapped_distribution_poisson_determinized
1601
+
1602
+ def _jax_negative_binomial(self, expr, aux):
1603
+ ERR = JaxRDDLCompilerWithGrad.ERROR_CODES['INVALID_PARAM_NEGATIVE_BINOMIAL']
1604
+ JaxRDDLCompilerWithGrad._check_num_args(expr, 2)
1605
+ arg_trials, arg_prob = expr.args
1606
+
1607
+ # if prob and trials is non-fluent, always use the exact operation
1608
+ if (not self.traced.cached_is_fluent(arg_trials) and
1609
+ not self.traced.cached_is_fluent(arg_prob)):
1610
+ return super()._jax_negative_binomial(expr, aux)
1611
+
1612
+ aux['overriden'][expr.id] = __class__.__name__
1613
+
1614
+ jax_trials = self._jax(arg_trials, aux)
1615
+ jax_prob = self._jax(arg_prob, aux)
1616
+
1617
+ def _jax_wrapped_distribution_negative_binomial_determinized(fls, nfls, params, key):
1618
+ trials, key, err2, params = jax_trials(fls, nfls, params, key)
1619
+ prob, key, err1, params = jax_prob(fls, nfls, params, key)
1620
+ trials = jnp.asarray(trials, dtype=self.REAL)
1621
+ prob = jnp.asarray(prob, dtype=self.REAL)
1622
+ sample = ((1.0 - prob) / prob) * trials
1623
+ out_of_bounds = jnp.logical_not(jnp.all(
1624
+ jnp.logical_and(jnp.logical_and(prob >= 0, prob <= 1), trials > 0)))
1625
+ err = err1 | err2 | (out_of_bounds * ERR)
1626
+ return sample, key, err, params
1627
+ return _jax_wrapped_distribution_negative_binomial_determinized
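DeterminizedPoisson trades sampling noise for bias by substituting each draw with its mean: E[Poisson(rate)] = rate and E[NegBinomial(trials, prob)] = trials * (1 - prob) / prob, which are exactly the expressions returned above. A toy illustration of the substitution and of the exact gradients it yields (the real methods also thread keys, params and error codes through the compiled sampler):

from jax import grad

def determinized_poisson(rate):
    # E[Poisson(rate)] = rate
    return rate

def determinized_negative_binomial(trials, prob):
    # E[NegBinomial(trials, prob)] = trials * (1 - prob) / prob
    return trials * (1.0 - prob) / prob

# gradients flow through the means with zero sampling variance
print(grad(determinized_negative_binomial, argnums=1)(5.0, 0.4))   # -31.25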
1628
+
1629
+
1630
+ class DefaultJaxRDDLCompilerWithGrad(SigmoidRelational, SoftmaxArgmax,
1631
+ ProductNormLogical,
1632
+ SafeSqrt, SoftFloor, SoftRound,
1633
+ LinearIfElse, SoftmaxSwitch,
1634
+ ReparameterizedGeometric,
1635
+ ReparameterizedSigmoidBernoulli,
1636
+ GumbelSoftmaxDiscrete, GumbelSoftmaxBinomial,
1637
+ ExponentialPoisson):
1638
+
1639
+ def __init__(self, *args, **kwargs) -> None:
1640
+ super(DefaultJaxRDDLCompilerWithGrad, self).__init__(*args, **kwargs)
1641
+
1642
+ def get_kwargs(self) -> Dict[str, Any]:
1643
+ kwargs = {}
1644
+ for base in type(self).__bases__:
1645
+ if base.__name__ != 'object':
1646
+ kwargs = {**kwargs, **base.get_kwargs(self)}
1647
+ return kwargs
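The default compiler is assembled purely from mixins, and get_kwargs collects the tuning knobs contributed by each direct base into one dictionary. A minimal sketch of the same merging pattern, with hypothetical mixin names invented for illustration; later bases overwrite earlier ones on key clashes, which is why the relaxations keep their parameter names distinct:

class _SigmoidMixin:
    def get_kwargs(self):
        return {'comparison_weight': 10.0}

class _SoftRoundMixin:
    def get_kwargs(self):
        return {'rounding_weight': 10.0}

class _Composite(_SigmoidMixin, _SoftRoundMixin):
    def get_kwargs(self):
        # merge the kwargs contributed by every direct base class
        kwargs = {}
        for base in type(self).__bases__:
            if base.__name__ != 'object':
                kwargs = {**kwargs, **base.get_kwargs(self)}
        return kwargs

print(_Composite().get_kwargs())   # {'comparison_weight': 10.0, 'rounding_weight': 10.0}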