pyRDDLGym-jax 0.2__py3-none-any.whl → 0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pyRDDLGym_jax/__init__.py CHANGED
@@ -0,0 +1 @@
1
+ __version__ = '0.4'
@@ -1,23 +1,11 @@
1
1
  from functools import partial
2
+ import traceback
3
+ from typing import Any, Callable, Dict, List, Optional
4
+
2
5
  import jax
3
6
  import jax.numpy as jnp
4
7
  import jax.random as random
5
8
  import jax.scipy as scipy
6
- import traceback
7
- from typing import Any, Callable, Dict, List, Optional
8
-
9
- from pyRDDLGym.core.debug.exception import raise_warning
10
-
11
- # more robust approach - if user does not have this or broken try to continue
12
- try:
13
- from tensorflow_probability.substrates import jax as tfp
14
- except Exception:
15
- raise_warning('Failed to import tensorflow-probability: '
16
- 'compilation of some complex distributions '
17
- '(Binomial, Negative-Binomial, Multinomial) will fail. '
18
- 'Please ensure this package is installed correctly.', 'red')
19
- traceback.print_exc()
20
- tfp = None
21
9
 
22
10
  from pyRDDLGym.core.compiler.initializer import RDDLValueInitializer
23
11
  from pyRDDLGym.core.compiler.levels import RDDLLevelAnalysis
@@ -26,12 +14,23 @@ from pyRDDLGym.core.compiler.tracer import RDDLObjectsTracer
26
14
  from pyRDDLGym.core.constraints import RDDLConstraints
27
15
  from pyRDDLGym.core.debug.exception import (
28
16
  print_stack_trace,
17
+ raise_warning,
29
18
  RDDLInvalidNumberOfArgumentsError,
30
19
  RDDLNotImplementedError
31
20
  )
32
21
  from pyRDDLGym.core.debug.logger import Logger
33
22
  from pyRDDLGym.core.simulator import RDDLSimulatorPrecompiled
34
23
 
24
+ # more robust approach - if user does not have this or broken try to continue
25
+ try:
26
+ from tensorflow_probability.substrates import jax as tfp
27
+ except Exception:
28
+ raise_warning('Failed to import tensorflow-probability: '
29
+ 'compilation of some complex distributions '
30
+ '(Binomial, Negative-Binomial, Multinomial) will fail.', 'red')
31
+ traceback.print_exc()
32
+ tfp = None
33
+
35
34
 
36
35
  # ===========================================================================
37
36
  # EXACT RDDL TO JAX COMPILATION RULES
@@ -88,7 +87,7 @@ def _function_aggregation_exact_named(op, name):
88
87
  def _function_if_exact_named():
89
88
 
90
89
  def _jax_wrapped_if_exact(c, a, b, param):
91
- return jnp.where(c, a, b)
90
+ return jnp.where(c > 0.5, a, b)
92
91
 
93
92
  return _jax_wrapped_if_exact
94
93
 
@@ -115,16 +114,27 @@ def _function_bernoulli_exact_named():
115
114
  def _function_discrete_exact_named():
116
115
 
117
116
  def _jax_wrapped_discrete_exact(key, prob, param):
118
- logits = jnp.log(prob)
119
- sample = random.categorical(key=key, logits=logits, axis=-1)
120
- out_of_bounds = jnp.logical_not(jnp.logical_and(
121
- jnp.all(prob >= 0),
122
- jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
123
- return sample, out_of_bounds
117
+ return random.categorical(key=key, logits=jnp.log(prob), axis=-1)
124
118
 
125
119
  return _jax_wrapped_discrete_exact
126
120
 
127
121
 
122
+ def _function_poisson_exact_named():
123
+
124
+ def _jax_wrapped_poisson_exact(key, rate, param):
125
+ return random.poisson(key=key, lam=rate, dtype=jnp.int64)
126
+
127
+ return _jax_wrapped_poisson_exact
128
+
129
+
130
+ def _function_geometric_exact_named():
131
+
132
+ def _jax_wrapped_geometric_exact(key, prob, param):
133
+ return random.geometric(key=key, p=prob, dtype=jnp.int64)
134
+
135
+ return _jax_wrapped_geometric_exact
136
+
137
+
128
138
  class JaxRDDLCompiler:
129
139
  '''Compiles a RDDL AST representation into an equivalent JAX representation.
130
140
  All operations are identical to their numpy equivalents.
@@ -211,12 +221,12 @@ class JaxRDDLCompiler:
211
221
  }
212
222
 
213
223
  EXACT_RDDL_TO_JAX_IF = _function_if_exact_named()
214
-
215
224
  EXACT_RDDL_TO_JAX_SWITCH = _function_switch_exact_named()
216
225
 
217
226
  EXACT_RDDL_TO_JAX_BERNOULLI = _function_bernoulli_exact_named()
218
-
219
227
  EXACT_RDDL_TO_JAX_DISCRETE = _function_discrete_exact_named()
228
+ EXACT_RDDL_TO_JAX_POISSON = _function_poisson_exact_named()
229
+ EXACT_RDDL_TO_JAX_GEOMETRIC = _function_geometric_exact_named()
220
230
 
221
231
  def __init__(self, rddl: RDDLLiftedModel,
222
232
  allow_synchronous_state: bool=True,
@@ -290,6 +300,8 @@ class JaxRDDLCompiler:
290
300
  self.SWITCH_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
291
301
  self.BERNOULLI_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BERNOULLI
292
302
  self.DISCRETE_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
303
+ self.POISSON_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_POISSON
304
+ self.GEOMETRIC_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_GEOMETRIC
293
305
 
294
306
  # ===========================================================================
295
307
  # main compilation subroutines
@@ -997,13 +1009,14 @@ class JaxRDDLCompiler:
997
1009
  jax_op, jax_param = self._unwrap(negative_op, expr.id, info)
998
1010
  return self._jax_unary(jax_expr, jax_op, jax_param, at_least_int=True)
999
1011
 
1000
- elif n == 2:
1001
- lhs, rhs = args
1002
- jax_lhs = self._jax(lhs, info)
1003
- jax_rhs = self._jax(rhs, info)
1012
+ elif n == 2 or (n >= 2 and op in {'*', '+'}):
1013
+ jax_exprs = [self._jax(arg, info) for arg in args]
1004
1014
  jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
1005
- return self._jax_binary(
1006
- jax_lhs, jax_rhs, jax_op, jax_param, at_least_int=True)
1015
+ result = jax_exprs[0]
1016
+ for jax_rhs in jax_exprs[1:]:
1017
+ result = self._jax_binary(
1018
+ result, jax_rhs, jax_op, jax_param, at_least_int=True)
1019
+ return result
1007
1020
 
1008
1021
  JaxRDDLCompiler._check_num_args(expr, 2)
1009
1022
 
@@ -1047,13 +1060,14 @@ class JaxRDDLCompiler:
1047
1060
  jax_op, jax_param = self._unwrap(logical_not_op, expr.id, info)
1048
1061
  return self._jax_unary(jax_expr, jax_op, jax_param, check_dtype=bool)
1049
1062
 
1050
- elif n == 2:
1051
- lhs, rhs = args
1052
- jax_lhs = self._jax(lhs, info)
1053
- jax_rhs = self._jax(rhs, info)
1063
+ elif n == 2 or (n >= 2 and op in {'^', '&', '|'}):
1064
+ jax_exprs = [self._jax(arg, info) for arg in args]
1054
1065
  jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
1055
- return self._jax_binary(
1056
- jax_lhs, jax_rhs, jax_op, jax_param, check_dtype=bool)
1066
+ result = jax_exprs[0]
1067
+ for jax_rhs in jax_exprs[1:]:
1068
+ result = self._jax_binary(
1069
+ result, jax_rhs, jax_op, jax_param, check_dtype=bool)
1070
+ return result
1057
1071
 
1058
1072
  JaxRDDLCompiler._check_num_args(expr, 2)
1059
1073
 
@@ -1166,16 +1180,17 @@ class JaxRDDLCompiler:
1166
1180
  return _jax_wrapped_if_then_else
1167
1181
 
1168
1182
  def _jax_switch(self, expr, info):
1183
+ pred, *_ = expr.args
1169
1184
 
1170
- # if expression is non-fluent, always use the exact operation
1171
- if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
1185
+ # if predicate is non-fluent, always use the exact operation
1186
+ # case conditions are currently only literals so they are non-fluent
1187
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(pred):
1172
1188
  switch_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
1173
1189
  else:
1174
1190
  switch_op = self.SWITCH_HELPER
1175
1191
  jax_switch, jax_param = self._unwrap(switch_op, expr.id, info)
1176
1192
 
1177
1193
  # recursively compile predicate
1178
- pred, *_ = expr.args
1179
1194
  jax_pred = self._jax(pred, info)
1180
1195
 
1181
1196
  # recursively compile cases
@@ -1427,15 +1442,24 @@ class JaxRDDLCompiler:
1427
1442
  def _jax_poisson(self, expr, info):
1428
1443
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_POISSON']
1429
1444
  JaxRDDLCompiler._check_num_args(expr, 1)
1430
-
1431
1445
  arg_rate, = expr.args
1446
+
1447
+ # if rate is non-fluent, always use the exact operation
1448
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_rate):
1449
+ poisson_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_POISSON
1450
+ else:
1451
+ poisson_op = self.POISSON_HELPER
1452
+ jax_poisson, jax_param = self._unwrap(poisson_op, expr.id, info)
1453
+
1454
+ # recursively compile arguments
1432
1455
  jax_rate = self._jax(arg_rate, info)
1433
1456
 
1434
1457
  # uses the implicit JAX subroutine
1435
1458
  def _jax_wrapped_distribution_poisson(x, params, key):
1436
1459
  rate, key, err = jax_rate(x, params, key)
1437
1460
  key, subkey = random.split(key)
1438
- sample = random.poisson(key=subkey, lam=rate, dtype=self.INT)
1461
+ param = params.get(jax_param, None)
1462
+ sample = jax_poisson(subkey, rate, param).astype(self.INT)
1439
1463
  out_of_bounds = jnp.logical_not(jnp.all(rate >= 0))
1440
1464
  err |= (out_of_bounds * ERR)
1441
1465
  return sample, key, err
@@ -1536,33 +1560,25 @@ class JaxRDDLCompiler:
1536
1560
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_GEOMETRIC']
1537
1561
  JaxRDDLCompiler._check_num_args(expr, 1)
1538
1562
  arg_prob, = expr.args
1563
+
1564
+ # if prob is non-fluent, always use the exact operation
1565
+ if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_prob):
1566
+ geom_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_GEOMETRIC
1567
+ else:
1568
+ geom_op = self.GEOMETRIC_HELPER
1569
+ jax_geom, jax_param = self._unwrap(geom_op, expr.id, info)
1570
+
1571
+ # recursively compile arguments
1539
1572
  jax_prob = self._jax(arg_prob, info)
1540
1573
 
1541
- if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_prob):
1542
-
1543
- # prob is non-fluent: do not reparameterize
1544
- def _jax_wrapped_distribution_geometric(x, params, key):
1545
- prob, key, err = jax_prob(x, params, key)
1546
- key, subkey = random.split(key)
1547
- sample = random.geometric(key=subkey, p=prob, dtype=self.INT)
1548
- out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
1549
- err |= (out_of_bounds * ERR)
1550
- return sample, key, err
1551
-
1552
- else:
1553
- floor_op, jax_param = self._unwrap(
1554
- self.KNOWN_UNARY['floor'], expr.id, info)
1555
-
1556
- # reparameterization trick Geom(p) = floor(ln(U(0, 1)) / ln(p)) + 1
1557
- def _jax_wrapped_distribution_geometric(x, params, key):
1558
- prob, key, err = jax_prob(x, params, key)
1559
- key, subkey = random.split(key)
1560
- U = random.uniform(key=subkey, shape=jnp.shape(prob), dtype=self.REAL)
1561
- param = params.get(jax_param, None)
1562
- sample = floor_op(jnp.log(U) / jnp.log(1.0 - prob), param) + 1
1563
- out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
1564
- err |= (out_of_bounds * ERR)
1565
- return sample, key, err
1574
+ def _jax_wrapped_distribution_geometric(x, params, key):
1575
+ prob, key, err = jax_prob(x, params, key)
1576
+ key, subkey = random.split(key)
1577
+ param = params.get(jax_param, None)
1578
+ sample = jax_geom(subkey, prob, param).astype(self.INT)
1579
+ out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
1580
+ err |= (out_of_bounds * ERR)
1581
+ return sample, key, err
1566
1582
 
1567
1583
  return _jax_wrapped_distribution_geometric
1568
1584
 
@@ -1771,7 +1787,10 @@ class JaxRDDLCompiler:
1771
1787
  # dispatch to sampling subroutine
1772
1788
  key, subkey = random.split(key)
1773
1789
  param = params.get(jax_param, None)
1774
- sample, out_of_bounds = jax_discrete(subkey, prob, param)
1790
+ sample = jax_discrete(subkey, prob, param)
1791
+ out_of_bounds = jnp.logical_not(jnp.logical_and(
1792
+ jnp.all(prob >= 0),
1793
+ jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
1775
1794
  error |= (out_of_bounds * ERR)
1776
1795
  return sample, key, error
1777
1796
 
@@ -1804,7 +1823,10 @@ class JaxRDDLCompiler:
1804
1823
  # dispatch to sampling subroutine
1805
1824
  key, subkey = random.split(key)
1806
1825
  param = params.get(jax_param, None)
1807
- sample, out_of_bounds = jax_discrete(subkey, prob, param)
1826
+ sample = jax_discrete(subkey, prob, param)
1827
+ out_of_bounds = jnp.logical_not(jnp.logical_and(
1828
+ jnp.all(prob >= 0),
1829
+ jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
1808
1830
  error |= (out_of_bounds * ERR)
1809
1831
  return sample, key, error
1810
1832
 
@@ -1,11 +1,19 @@
1
+ from typing import Optional, Set
2
+
1
3
  import jax
2
4
  import jax.numpy as jnp
3
5
  import jax.random as random
4
- from typing import Optional, Set
5
6
 
6
7
  from pyRDDLGym.core.debug.exception import raise_warning
7
8
 
8
9
 
10
+ # ===========================================================================
11
+ # LOGICAL COMPLEMENT
12
+ # - abstract class
13
+ # - standard complement
14
+ #
15
+ # ===========================================================================
16
+
9
17
  class Complement:
10
18
  '''Base class for approximate logical complement operations.'''
11
19
 
@@ -20,6 +28,13 @@ class StandardComplement(Complement):
20
28
  return 1.0 - x
21
29
 
22
30
 
31
+ # ===========================================================================
32
+ # RELATIONAL OPERATIONS
33
+ # - abstract class
34
+ # - sigmoid comparison
35
+ #
36
+ # ===========================================================================
37
+
23
38
  class Comparison:
24
39
  '''Base class for approximate comparison operations.'''
25
40
 
@@ -44,7 +59,17 @@ class SigmoidComparison(Comparison):
44
59
 
45
60
  def equal(self, x, y, param):
46
61
  return 1.0 - jnp.square(jnp.tanh(param * (y - x)))
47
-
62
+
63
+
64
+ # ===========================================================================
65
+ # TNORMS
66
+ # - abstract tnorm
67
+ # - product tnorm
68
+ # - Godel tnorm
69
+ # - Lukasiewicz tnorm
70
+ # - Yager(p) tnorm
71
+ #
72
+ # ===========================================================================
48
73
 
49
74
  class TNorm:
50
75
  '''Base class for fuzzy differentiable t-norms.'''
@@ -86,8 +111,133 @@ class LukasiewiczTNorm(TNorm):
86
111
 
87
112
  def norms(self, x, axis):
88
113
  return jax.nn.relu(jnp.sum(x - 1.0, axis=axis) + 1.0)
114
+
115
+
116
+ class YagerTNorm(TNorm):
117
+ '''Yager t-norm given by the expression
118
+ (x, y) -> max(1 - ((1 - x)^p + (1 - y)^p)^(1/p)).'''
119
+
120
+ def __init__(self, p=2.0):
121
+ self.p = p
122
+
123
+ def norm(self, x, y):
124
+ base_x = jax.nn.relu(1.0 - x)
125
+ base_y = jax.nn.relu(1.0 - y)
126
+ arg = jnp.power(base_x ** self.p + base_y ** self.p, 1.0 / self.p)
127
+ return jax.nn.relu(1.0 - arg)
128
+
129
+ def norms(self, x, axis):
130
+ base = jax.nn.relu(1.0 - x)
131
+ arg = jnp.power(jnp.sum(base ** self.p, axis=axis), 1.0 / self.p)
132
+ return jax.nn.relu(1.0 - arg)
133
+
134
+
135
+ # ===========================================================================
136
+ # RANDOM SAMPLING
137
+ # - abstract sampler
138
+ # - Gumbel-softmax sampler
139
+ # - determinization
140
+ #
141
+ # ===========================================================================
142
+
143
+ class RandomSampling:
144
+ '''An abstract class that describes how discrete and non-reparameterizable
145
+ random variables are sampled.'''
146
+
147
+ def discrete(self, logic):
148
+ raise NotImplementedError
149
+
150
+ def bernoulli(self, logic):
151
+ jax_discrete, jax_param = self.discrete(logic)
152
+
153
+ def _jax_wrapped_calc_bernoulli_approx(key, prob, param):
154
+ prob = jnp.stack([1.0 - prob, prob], axis=-1)
155
+ sample = jax_discrete(key, prob, param)
156
+ return sample
157
+
158
+ return _jax_wrapped_calc_bernoulli_approx, jax_param
159
+
160
+ def poisson(self, logic):
161
+
162
+ def _jax_wrapped_calc_poisson_exact(key, rate, param):
163
+ return random.poisson(key=key, lam=rate, dtype=logic.INT)
164
+
165
+ return _jax_wrapped_calc_poisson_exact, None
166
+
167
+ def geometric(self, logic):
168
+ if logic.verbose:
169
+ raise_warning('Using the replacement rule: '
170
+ 'Geometric(p) --> floor(log(U) / log(1 - p)) + 1')
171
+
172
+ jax_floor, jax_param = logic.floor()
173
+
174
+ def _jax_wrapped_calc_geometric_approx(key, prob, param):
175
+ U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
176
+ sample = jax_floor(jnp.log(U) / jnp.log(1.0 - prob), param) + 1
177
+ return sample
178
+
179
+ return _jax_wrapped_calc_geometric_approx, jax_param
180
+
181
+
182
+ class GumbelSoftmax(RandomSampling):
183
+ '''Random sampling of discrete variables using Gumbel-softmax trick.'''
89
184
 
185
+ def discrete(self, logic):
186
+ if logic.verbose:
187
+ raise_warning('Using the replacement rule: '
188
+ 'Discrete(p) --> Gumbel-softmax(p)')
189
+
190
+ jax_argmax, jax_param = logic.argmax()
191
+
192
+ def _jax_wrapped_calc_discrete_gumbel_softmax(key, prob, param):
193
+ Gumbel01 = random.gumbel(key=key, shape=prob.shape, dtype=logic.REAL)
194
+ sample = Gumbel01 + jnp.log(prob + logic.eps)
195
+ sample = jax_argmax(sample, axis=-1, param=param)
196
+ return sample
197
+
198
+ return _jax_wrapped_calc_discrete_gumbel_softmax, jax_param
199
+
200
+
201
+ class Determinization(RandomSampling):
202
+ '''Random sampling of variables using their deterministic mean estimate.'''
90
203
 
204
+ def discrete(self, logic):
205
+ if logic.verbose:
206
+ raise_warning('Using the replacement rule: '
207
+ 'Discrete(p) --> sum(i * p[i])')
208
+
209
+ def _jax_wrapped_calc_discrete_determinized(key, prob, param):
210
+ literals = FuzzyLogic.enumerate_literals(prob.shape, axis=-1)
211
+ sample = jnp.sum(literals * prob, axis=-1)
212
+ return sample
213
+
214
+ return _jax_wrapped_calc_discrete_determinized, None
215
+
216
+ def poisson(self, logic):
217
+ if logic.verbose:
218
+ raise_warning('Using the replacement rule: Poisson(rate) --> rate')
219
+
220
+ def _jax_wrapped_calc_poisson_determinized(key, rate, param):
221
+ return rate
222
+
223
+ return _jax_wrapped_calc_poisson_determinized, None
224
+
225
+ def geometric(self, logic):
226
+ if logic.verbose:
227
+ raise_warning('Using the replacement rule: Geometric(p) --> 1 / p')
228
+
229
+ def _jax_wrapped_calc_geometric_determinized(key, prob, param):
230
+ sample = 1.0 / prob
231
+ return sample
232
+
233
+ return _jax_wrapped_calc_geometric_determinized, None
234
+
235
+
236
+ # ===========================================================================
237
+ # FUZZY LOGIC
238
+ #
239
+ # ===========================================================================
240
+
91
241
  class FuzzyLogic:
92
242
  '''A class representing fuzzy logic in JAX.
93
243
 
@@ -98,9 +248,10 @@ class FuzzyLogic:
98
248
  def __init__(self, tnorm: TNorm=ProductTNorm(),
99
249
  complement: Complement=StandardComplement(),
100
250
  comparison: Comparison=SigmoidComparison(),
251
+ sampling: RandomSampling=GumbelSoftmax(),
101
252
  weight: float=10.0,
102
253
  debias: Optional[Set[str]]=None,
103
- eps: float=1e-10,
254
+ eps: float=1e-15,
104
255
  verbose: bool=False,
105
256
  use64bit: bool=False) -> None:
106
257
  '''Creates a new fuzzy logic in Jax.
@@ -108,8 +259,8 @@ class FuzzyLogic:
108
259
  :param tnorm: fuzzy operator for logical AND
109
260
  :param complement: fuzzy operator for logical NOT
110
261
  :param comparison: fuzzy operator for comparisons (>, >=, <, ==, ~=, ...)
262
+ :param sampling: random sampling of non-reparameterizable distributions
111
263
  :param weight: a sharpness parameter for sigmoid and softmax activations
112
- :param error: an error parameter (e.g. floor) (smaller means better accuracy)
113
264
  :param debias: which functions to de-bias approximate on forward pass
114
265
  :param eps: small positive float to mitigate underflow
115
266
  :param verbose: whether to dump replacements and other info to console
@@ -118,6 +269,7 @@ class FuzzyLogic:
118
269
  self.tnorm = tnorm
119
270
  self.complement = complement
120
271
  self.comparison = comparison
272
+ self.sampling = sampling
121
273
  self.weight = float(weight)
122
274
  if debias is None:
123
275
  debias = set()
@@ -142,10 +294,11 @@ class FuzzyLogic:
142
294
  f' tnorm ={type(self.tnorm).__name__}\n'
143
295
  f' complement ={type(self.complement).__name__}\n'
144
296
  f' comparison ={type(self.comparison).__name__}\n'
297
+ f' sampling ={type(self.sampling).__name__}\n'
145
298
  f' sigmoid_weight={self.weight}\n'
146
299
  f' cpfs_to_debias={self.debias}\n'
147
300
  f' underflow_tol ={self.eps}\n'
148
- f' use64bit ={self.use64bit}')
301
+ f' use_64_bit ={self.use64bit}')
149
302
 
150
303
  # ===========================================================================
151
304
  # logical operators
@@ -419,7 +572,7 @@ class FuzzyLogic:
419
572
  # ===========================================================================
420
573
 
421
574
  @staticmethod
422
- def _literals(shape, axis):
575
+ def enumerate_literals(shape, axis):
423
576
  literals = jnp.arange(shape[axis])
424
577
  literals = literals[(...,) + (jnp.newaxis,) * (len(shape) - 1)]
425
578
  literals = jnp.moveaxis(literals, source=0, destination=axis)
@@ -434,7 +587,7 @@ class FuzzyLogic:
434
587
  debias = 'argmax' in self.debias
435
588
 
436
589
  def _jax_wrapped_calc_argmax_approx(x, axis, param):
437
- literals = FuzzyLogic._literals(x.shape, axis=axis)
590
+ literals = FuzzyLogic.enumerate_literals(x.shape, axis=axis)
438
591
  soft_max = jax.nn.softmax(param * x, axis=axis)
439
592
  sample = jnp.sum(literals * soft_max, axis=axis)
440
593
  if debias:
@@ -468,7 +621,7 @@ class FuzzyLogic:
468
621
  def _jax_wrapped_calc_if_approx(c, a, b, param):
469
622
  sample = c * a + (1.0 - c) * b
470
623
  if debias:
471
- hard_sample = jnp.select([c, ~c], [a, b])
624
+ hard_sample = jnp.where(c > 0.5, a, b)
472
625
  sample += jax.lax.stop_gradient(hard_sample - sample)
473
626
  return sample
474
627
 
@@ -483,7 +636,7 @@ class FuzzyLogic:
483
636
  debias = 'switch' in self.debias
484
637
 
485
638
  def _jax_wrapped_calc_switch_approx(pred, cases, param):
486
- literals = FuzzyLogic._literals(cases.shape, axis=0)
639
+ literals = FuzzyLogic.enumerate_literals(cases.shape, axis=0)
487
640
  pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=cases.shape)
488
641
  proximity = -jnp.abs(pred - literals)
489
642
  soft_case = jax.nn.softmax(param * proximity, axis=0)
@@ -502,44 +655,24 @@ class FuzzyLogic:
502
655
  # random variables
503
656
  # ===========================================================================
504
657
 
505
- def _gumbel_softmax(self, key, prob):
506
- Gumbel01 = random.gumbel(key=key, shape=prob.shape, dtype=self.REAL)
507
- sample = Gumbel01 + jnp.log(prob + self.eps)
508
- return sample
509
-
658
+ def discrete(self):
659
+ return self.sampling.discrete(self)
660
+
510
661
  def bernoulli(self):
511
- if self.verbose:
512
- raise_warning('Using the replacement rule: '
513
- 'Bernoulli(p) --> Gumbel-softmax(p)')
514
-
515
- jax_gs = self._gumbel_softmax
516
- jax_argmax, jax_param = self.argmax()
517
-
518
- def _jax_wrapped_calc_bernoulli_approx(key, prob, param):
519
- prob = jnp.stack([1.0 - prob, prob], axis=-1)
520
- sample = jax_gs(key, prob)
521
- sample = jax_argmax(sample, axis=-1, param=param)
522
- return sample
523
-
524
- return _jax_wrapped_calc_bernoulli_approx, jax_param
662
+ return self.sampling.bernoulli(self)
525
663
 
526
- def discrete(self):
527
- if self.verbose:
528
- raise_warning('Using the replacement rule: '
529
- 'Discrete(p) --> Gumbel-softmax(p)')
530
-
531
- jax_gs = self._gumbel_softmax
532
- jax_argmax, jax_param = self.argmax()
533
-
534
- def _jax_wrapped_calc_discrete_approx(key, prob, param):
535
- sample = jax_gs(key, prob)
536
- sample = jax_argmax(sample, axis=-1, param=param)
537
- return sample
538
-
539
- return _jax_wrapped_calc_discrete_approx, jax_param
540
-
664
+ def poisson(self):
665
+ return self.sampling.poisson(self)
666
+
667
+ def geometric(self):
668
+ return self.sampling.geometric(self)
669
+
541
670
 
671
+ # ===========================================================================
542
672
  # UNIT TESTS
673
+ #
674
+ # ===========================================================================
675
+
543
676
  logic = FuzzyLogic()
544
677
  w = 100.0
545
678
 
@@ -598,13 +731,14 @@ def _test_random():
598
731
  key = random.PRNGKey(42)
599
732
  _bernoulli, _ = logic.bernoulli()
600
733
  _discrete, _ = logic.discrete()
734
+ _geometric, _ = logic.geometric()
601
735
 
602
736
  def bern(n):
603
737
  prob = jnp.asarray([0.3] * n)
604
738
  sample = _bernoulli(key, prob, w)
605
739
  return sample
606
740
 
607
- samples = bern(5000)
741
+ samples = bern(50000)
608
742
  print(jnp.mean(samples))
609
743
 
610
744
  def disc(n):
@@ -613,10 +747,18 @@ def _test_random():
613
747
  sample = _discrete(key, prob, w)
614
748
  return sample
615
749
 
616
- samples = disc(5000)
750
+ samples = disc(50000)
617
751
  samples = jnp.round(samples)
618
752
  print([jnp.mean(samples == i) for i in range(3)])
619
-
753
+
754
+ def geom(n):
755
+ prob = jnp.asarray([0.3] * n)
756
+ sample = _geometric(key, prob, w)
757
+ return sample
758
+
759
+ samples = geom(50000)
760
+ print(jnp.mean(samples))
761
+
620
762
 
621
763
  def _test_rounding():
622
764
  print('testing rounding')