pyRDDLGym-jax 0.2__py3-none-any.whl → 0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -0
- pyRDDLGym_jax/core/compiler.py +90 -68
- pyRDDLGym_jax/core/logic.py +188 -46
- pyRDDLGym_jax/core/planner.py +411 -195
- pyRDDLGym_jax/core/simulator.py +2 -1
- pyRDDLGym_jax/core/tuning.py +13 -10
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +1 -1
- pyRDDLGym_jax/examples/configs/default_drp.cfg +1 -1
- pyRDDLGym_jax/examples/configs/default_slp.cfg +1 -1
- pyRDDLGym_jax/examples/run_gym.py +2 -5
- pyRDDLGym_jax/examples/run_plan.py +6 -8
- pyRDDLGym_jax/examples/run_scipy.py +61 -0
- pyRDDLGym_jax/examples/run_tune.py +5 -6
- pyRDDLGym_jax-0.4.dist-info/METADATA +276 -0
- {pyRDDLGym_jax-0.2.dist-info → pyRDDLGym_jax-0.4.dist-info}/RECORD +20 -22
- {pyRDDLGym_jax-0.2.dist-info → pyRDDLGym_jax-0.4.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax/examples/configs/Pong_slp.cfg +0 -18
- pyRDDLGym_jax/examples/configs/SupplyChain_slp.cfg +0 -18
- pyRDDLGym_jax/examples/configs/Traffic_slp.cfg +0 -20
- pyRDDLGym_jax-0.2.dist-info/METADATA +0 -26
- {pyRDDLGym_jax-0.2.dist-info → pyRDDLGym_jax-0.4.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-0.2.dist-info → pyRDDLGym_jax-0.4.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/__init__.py
CHANGED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = '0.4'
|
pyRDDLGym_jax/core/compiler.py
CHANGED
|
@@ -1,23 +1,11 @@
|
|
|
1
1
|
from functools import partial
|
|
2
|
+
import traceback
|
|
3
|
+
from typing import Any, Callable, Dict, List, Optional
|
|
4
|
+
|
|
2
5
|
import jax
|
|
3
6
|
import jax.numpy as jnp
|
|
4
7
|
import jax.random as random
|
|
5
8
|
import jax.scipy as scipy
|
|
6
|
-
import traceback
|
|
7
|
-
from typing import Any, Callable, Dict, List, Optional
|
|
8
|
-
|
|
9
|
-
from pyRDDLGym.core.debug.exception import raise_warning
|
|
10
|
-
|
|
11
|
-
# more robust approach - if user does not have this or broken try to continue
|
|
12
|
-
try:
|
|
13
|
-
from tensorflow_probability.substrates import jax as tfp
|
|
14
|
-
except Exception:
|
|
15
|
-
raise_warning('Failed to import tensorflow-probability: '
|
|
16
|
-
'compilation of some complex distributions '
|
|
17
|
-
'(Binomial, Negative-Binomial, Multinomial) will fail. '
|
|
18
|
-
'Please ensure this package is installed correctly.', 'red')
|
|
19
|
-
traceback.print_exc()
|
|
20
|
-
tfp = None
|
|
21
9
|
|
|
22
10
|
from pyRDDLGym.core.compiler.initializer import RDDLValueInitializer
|
|
23
11
|
from pyRDDLGym.core.compiler.levels import RDDLLevelAnalysis
|
|
@@ -26,12 +14,23 @@ from pyRDDLGym.core.compiler.tracer import RDDLObjectsTracer
|
|
|
26
14
|
from pyRDDLGym.core.constraints import RDDLConstraints
|
|
27
15
|
from pyRDDLGym.core.debug.exception import (
|
|
28
16
|
print_stack_trace,
|
|
17
|
+
raise_warning,
|
|
29
18
|
RDDLInvalidNumberOfArgumentsError,
|
|
30
19
|
RDDLNotImplementedError
|
|
31
20
|
)
|
|
32
21
|
from pyRDDLGym.core.debug.logger import Logger
|
|
33
22
|
from pyRDDLGym.core.simulator import RDDLSimulatorPrecompiled
|
|
34
23
|
|
|
24
|
+
# more robust approach - if tensorflow-probability is missing or broken, warn and try to continue
|
|
25
|
+
try:
|
|
26
|
+
from tensorflow_probability.substrates import jax as tfp
|
|
27
|
+
except Exception:
|
|
28
|
+
raise_warning('Failed to import tensorflow-probability: '
|
|
29
|
+
'compilation of some complex distributions '
|
|
30
|
+
'(Binomial, Negative-Binomial, Multinomial) will fail.', 'red')
|
|
31
|
+
traceback.print_exc()
|
|
32
|
+
tfp = None
|
|
33
|
+
|
|
35
34
|
|
|
36
35
|
# ===========================================================================
|
|
37
36
|
# EXACT RDDL TO JAX COMPILATION RULES
|
|
@@ -88,7 +87,7 @@ def _function_aggregation_exact_named(op, name):
|
|
|
88
87
|
def _function_if_exact_named():
|
|
89
88
|
|
|
90
89
|
def _jax_wrapped_if_exact(c, a, b, param):
|
|
91
|
-
return jnp.where(c, a, b)
|
|
90
|
+
return jnp.where(c > 0.5, a, b)
|
|
92
91
|
|
|
93
92
|
return _jax_wrapped_if_exact
|
|
94
93
|
|
|
@@ -115,16 +114,27 @@ def _function_bernoulli_exact_named():
|
|
|
115
114
|
def _function_discrete_exact_named():
|
|
116
115
|
|
|
117
116
|
def _jax_wrapped_discrete_exact(key, prob, param):
|
|
118
|
-
|
|
119
|
-
sample = random.categorical(key=key, logits=logits, axis=-1)
|
|
120
|
-
out_of_bounds = jnp.logical_not(jnp.logical_and(
|
|
121
|
-
jnp.all(prob >= 0),
|
|
122
|
-
jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
|
|
123
|
-
return sample, out_of_bounds
|
|
117
|
+
return random.categorical(key=key, logits=jnp.log(prob), axis=-1)
|
|
124
118
|
|
|
125
119
|
return _jax_wrapped_discrete_exact
|
|
126
120
|
|
|
127
121
|
|
|
122
|
+
def _function_poisson_exact_named():
|
|
123
|
+
|
|
124
|
+
def _jax_wrapped_poisson_exact(key, rate, param):
|
|
125
|
+
return random.poisson(key=key, lam=rate, dtype=jnp.int64)
|
|
126
|
+
|
|
127
|
+
return _jax_wrapped_poisson_exact
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def _function_geometric_exact_named():
|
|
131
|
+
|
|
132
|
+
def _jax_wrapped_geometric_exact(key, prob, param):
|
|
133
|
+
return random.geometric(key=key, p=prob, dtype=jnp.int64)
|
|
134
|
+
|
|
135
|
+
return _jax_wrapped_geometric_exact
|
|
136
|
+
|
|
137
|
+
|
|
128
138
|
class JaxRDDLCompiler:
|
|
129
139
|
'''Compiles a RDDL AST representation into an equivalent JAX representation.
|
|
130
140
|
All operations are identical to their numpy equivalents.
|
|
@@ -211,12 +221,12 @@ class JaxRDDLCompiler:
|
|
|
211
221
|
}
|
|
212
222
|
|
|
213
223
|
EXACT_RDDL_TO_JAX_IF = _function_if_exact_named()
|
|
214
|
-
|
|
215
224
|
EXACT_RDDL_TO_JAX_SWITCH = _function_switch_exact_named()
|
|
216
225
|
|
|
217
226
|
EXACT_RDDL_TO_JAX_BERNOULLI = _function_bernoulli_exact_named()
|
|
218
|
-
|
|
219
227
|
EXACT_RDDL_TO_JAX_DISCRETE = _function_discrete_exact_named()
|
|
228
|
+
EXACT_RDDL_TO_JAX_POISSON = _function_poisson_exact_named()
|
|
229
|
+
EXACT_RDDL_TO_JAX_GEOMETRIC = _function_geometric_exact_named()
|
|
220
230
|
|
|
221
231
|
def __init__(self, rddl: RDDLLiftedModel,
|
|
222
232
|
allow_synchronous_state: bool=True,
|
|
@@ -290,6 +300,8 @@ class JaxRDDLCompiler:
|
|
|
290
300
|
self.SWITCH_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
|
|
291
301
|
self.BERNOULLI_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BERNOULLI
|
|
292
302
|
self.DISCRETE_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
|
|
303
|
+
self.POISSON_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_POISSON
|
|
304
|
+
self.GEOMETRIC_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_GEOMETRIC
|
|
293
305
|
|
|
294
306
|
# ===========================================================================
|
|
295
307
|
# main compilation subroutines
|
|
@@ -997,13 +1009,14 @@ class JaxRDDLCompiler:
|
|
|
997
1009
|
jax_op, jax_param = self._unwrap(negative_op, expr.id, info)
|
|
998
1010
|
return self._jax_unary(jax_expr, jax_op, jax_param, at_least_int=True)
|
|
999
1011
|
|
|
1000
|
-
elif n == 2:
|
|
1001
|
-
|
|
1002
|
-
jax_lhs = self._jax(lhs, info)
|
|
1003
|
-
jax_rhs = self._jax(rhs, info)
|
|
1012
|
+
elif n == 2 or (n >= 2 and op in {'*', '+'}):
|
|
1013
|
+
jax_exprs = [self._jax(arg, info) for arg in args]
|
|
1004
1014
|
jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
|
|
1005
|
-
|
|
1006
|
-
|
|
1015
|
+
result = jax_exprs[0]
|
|
1016
|
+
for jax_rhs in jax_exprs[1:]:
|
|
1017
|
+
result = self._jax_binary(
|
|
1018
|
+
result, jax_rhs, jax_op, jax_param, at_least_int=True)
|
|
1019
|
+
return result
|
|
1007
1020
|
|
|
1008
1021
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1009
1022
|
|
|
@@ -1047,13 +1060,14 @@ class JaxRDDLCompiler:
|
|
|
1047
1060
|
jax_op, jax_param = self._unwrap(logical_not_op, expr.id, info)
|
|
1048
1061
|
return self._jax_unary(jax_expr, jax_op, jax_param, check_dtype=bool)
|
|
1049
1062
|
|
|
1050
|
-
elif n == 2:
|
|
1051
|
-
|
|
1052
|
-
jax_lhs = self._jax(lhs, info)
|
|
1053
|
-
jax_rhs = self._jax(rhs, info)
|
|
1063
|
+
elif n == 2 or (n >= 2 and op in {'^', '&', '|'}):
|
|
1064
|
+
jax_exprs = [self._jax(arg, info) for arg in args]
|
|
1054
1065
|
jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
|
|
1055
|
-
|
|
1056
|
-
|
|
1066
|
+
result = jax_exprs[0]
|
|
1067
|
+
for jax_rhs in jax_exprs[1:]:
|
|
1068
|
+
result = self._jax_binary(
|
|
1069
|
+
result, jax_rhs, jax_op, jax_param, check_dtype=bool)
|
|
1070
|
+
return result
|
|
1057
1071
|
|
|
1058
1072
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1059
1073
|
|
|
@@ -1166,16 +1180,17 @@ class JaxRDDLCompiler:
|
|
|
1166
1180
|
return _jax_wrapped_if_then_else
|
|
1167
1181
|
|
|
1168
1182
|
def _jax_switch(self, expr, info):
|
|
1183
|
+
pred, *_ = expr.args
|
|
1169
1184
|
|
|
1170
|
-
# if
|
|
1171
|
-
|
|
1185
|
+
# if predicate is non-fluent, always use the exact operation
|
|
1186
|
+
# case conditions are currently only literals so they are non-fluent
|
|
1187
|
+
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(pred):
|
|
1172
1188
|
switch_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
|
|
1173
1189
|
else:
|
|
1174
1190
|
switch_op = self.SWITCH_HELPER
|
|
1175
1191
|
jax_switch, jax_param = self._unwrap(switch_op, expr.id, info)
|
|
1176
1192
|
|
|
1177
1193
|
# recursively compile predicate
|
|
1178
|
-
pred, *_ = expr.args
|
|
1179
1194
|
jax_pred = self._jax(pred, info)
|
|
1180
1195
|
|
|
1181
1196
|
# recursively compile cases
|
|
@@ -1427,15 +1442,24 @@ class JaxRDDLCompiler:
|
|
|
1427
1442
|
def _jax_poisson(self, expr, info):
|
|
1428
1443
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_POISSON']
|
|
1429
1444
|
JaxRDDLCompiler._check_num_args(expr, 1)
|
|
1430
|
-
|
|
1431
1445
|
arg_rate, = expr.args
|
|
1446
|
+
|
|
1447
|
+
# if rate is non-fluent, always use the exact operation
|
|
1448
|
+
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_rate):
|
|
1449
|
+
poisson_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_POISSON
|
|
1450
|
+
else:
|
|
1451
|
+
poisson_op = self.POISSON_HELPER
|
|
1452
|
+
jax_poisson, jax_param = self._unwrap(poisson_op, expr.id, info)
|
|
1453
|
+
|
|
1454
|
+
# recursively compile arguments
|
|
1432
1455
|
jax_rate = self._jax(arg_rate, info)
|
|
1433
1456
|
|
|
1434
1457
|
# uses the implicit JAX subroutine
|
|
1435
1458
|
def _jax_wrapped_distribution_poisson(x, params, key):
|
|
1436
1459
|
rate, key, err = jax_rate(x, params, key)
|
|
1437
1460
|
key, subkey = random.split(key)
|
|
1438
|
-
|
|
1461
|
+
param = params.get(jax_param, None)
|
|
1462
|
+
sample = jax_poisson(subkey, rate, param).astype(self.INT)
|
|
1439
1463
|
out_of_bounds = jnp.logical_not(jnp.all(rate >= 0))
|
|
1440
1464
|
err |= (out_of_bounds * ERR)
|
|
1441
1465
|
return sample, key, err
|
|
@@ -1536,33 +1560,25 @@ class JaxRDDLCompiler:
|
|
|
1536
1560
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_GEOMETRIC']
|
|
1537
1561
|
JaxRDDLCompiler._check_num_args(expr, 1)
|
|
1538
1562
|
arg_prob, = expr.args
|
|
1563
|
+
|
|
1564
|
+
# if prob is non-fluent, always use the exact operation
|
|
1565
|
+
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_prob):
|
|
1566
|
+
geom_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_GEOMETRIC
|
|
1567
|
+
else:
|
|
1568
|
+
geom_op = self.GEOMETRIC_HELPER
|
|
1569
|
+
jax_geom, jax_param = self._unwrap(geom_op, expr.id, info)
|
|
1570
|
+
|
|
1571
|
+
# recursively compile arguments
|
|
1539
1572
|
jax_prob = self._jax(arg_prob, info)
|
|
1540
1573
|
|
|
1541
|
-
|
|
1542
|
-
|
|
1543
|
-
|
|
1544
|
-
|
|
1545
|
-
|
|
1546
|
-
|
|
1547
|
-
|
|
1548
|
-
|
|
1549
|
-
err |= (out_of_bounds * ERR)
|
|
1550
|
-
return sample, key, err
|
|
1551
|
-
|
|
1552
|
-
else:
|
|
1553
|
-
floor_op, jax_param = self._unwrap(
|
|
1554
|
-
self.KNOWN_UNARY['floor'], expr.id, info)
|
|
1555
|
-
|
|
1556
|
-
# reparameterization trick Geom(p) = floor(ln(U(0, 1)) / ln(p)) + 1
|
|
1557
|
-
def _jax_wrapped_distribution_geometric(x, params, key):
|
|
1558
|
-
prob, key, err = jax_prob(x, params, key)
|
|
1559
|
-
key, subkey = random.split(key)
|
|
1560
|
-
U = random.uniform(key=subkey, shape=jnp.shape(prob), dtype=self.REAL)
|
|
1561
|
-
param = params.get(jax_param, None)
|
|
1562
|
-
sample = floor_op(jnp.log(U) / jnp.log(1.0 - prob), param) + 1
|
|
1563
|
-
out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
|
|
1564
|
-
err |= (out_of_bounds * ERR)
|
|
1565
|
-
return sample, key, err
|
|
1574
|
+
def _jax_wrapped_distribution_geometric(x, params, key):
|
|
1575
|
+
prob, key, err = jax_prob(x, params, key)
|
|
1576
|
+
key, subkey = random.split(key)
|
|
1577
|
+
param = params.get(jax_param, None)
|
|
1578
|
+
sample = jax_geom(subkey, prob, param).astype(self.INT)
|
|
1579
|
+
out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
|
|
1580
|
+
err |= (out_of_bounds * ERR)
|
|
1581
|
+
return sample, key, err
|
|
1566
1582
|
|
|
1567
1583
|
return _jax_wrapped_distribution_geometric
|
|
1568
1584
|
|
|
@@ -1771,7 +1787,10 @@ class JaxRDDLCompiler:
|
|
|
1771
1787
|
# dispatch to sampling subroutine
|
|
1772
1788
|
key, subkey = random.split(key)
|
|
1773
1789
|
param = params.get(jax_param, None)
|
|
1774
|
-
sample
|
|
1790
|
+
sample = jax_discrete(subkey, prob, param)
|
|
1791
|
+
out_of_bounds = jnp.logical_not(jnp.logical_and(
|
|
1792
|
+
jnp.all(prob >= 0),
|
|
1793
|
+
jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
|
|
1775
1794
|
error |= (out_of_bounds * ERR)
|
|
1776
1795
|
return sample, key, error
|
|
1777
1796
|
|
|
@@ -1804,7 +1823,10 @@ class JaxRDDLCompiler:
|
|
|
1804
1823
|
# dispatch to sampling subroutine
|
|
1805
1824
|
key, subkey = random.split(key)
|
|
1806
1825
|
param = params.get(jax_param, None)
|
|
1807
|
-
sample
|
|
1826
|
+
sample = jax_discrete(subkey, prob, param)
|
|
1827
|
+
out_of_bounds = jnp.logical_not(jnp.logical_and(
|
|
1828
|
+
jnp.all(prob >= 0),
|
|
1829
|
+
jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
|
|
1808
1830
|
error |= (out_of_bounds * ERR)
|
|
1809
1831
|
return sample, key, error
|
|
1810
1832
|
|
pyRDDLGym_jax/core/logic.py
CHANGED
|
@@ -1,11 +1,19 @@
|
|
|
1
|
+
from typing import Optional, Set
|
|
2
|
+
|
|
1
3
|
import jax
|
|
2
4
|
import jax.numpy as jnp
|
|
3
5
|
import jax.random as random
|
|
4
|
-
from typing import Optional, Set
|
|
5
6
|
|
|
6
7
|
from pyRDDLGym.core.debug.exception import raise_warning
|
|
7
8
|
|
|
8
9
|
|
|
10
|
+
# ===========================================================================
|
|
11
|
+
# LOGICAL COMPLEMENT
|
|
12
|
+
# - abstract class
|
|
13
|
+
# - standard complement
|
|
14
|
+
#
|
|
15
|
+
# ===========================================================================
|
|
16
|
+
|
|
9
17
|
class Complement:
|
|
10
18
|
'''Base class for approximate logical complement operations.'''
|
|
11
19
|
|
|
@@ -20,6 +28,13 @@ class StandardComplement(Complement):
|
|
|
20
28
|
return 1.0 - x
|
|
21
29
|
|
|
22
30
|
|
|
31
|
+
# ===========================================================================
|
|
32
|
+
# RELATIONAL OPERATIONS
|
|
33
|
+
# - abstract class
|
|
34
|
+
# - sigmoid comparison
|
|
35
|
+
#
|
|
36
|
+
# ===========================================================================
|
|
37
|
+
|
|
23
38
|
class Comparison:
|
|
24
39
|
'''Base class for approximate comparison operations.'''
|
|
25
40
|
|
|
@@ -44,7 +59,17 @@ class SigmoidComparison(Comparison):
|
|
|
44
59
|
|
|
45
60
|
def equal(self, x, y, param):
|
|
46
61
|
return 1.0 - jnp.square(jnp.tanh(param * (y - x)))
|
|
47
|
-
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# ===========================================================================
|
|
65
|
+
# TNORMS
|
|
66
|
+
# - abstract tnorm
|
|
67
|
+
# - product tnorm
|
|
68
|
+
# - Godel tnorm
|
|
69
|
+
# - Lukasiewicz tnorm
|
|
70
|
+
# - Yager(p) tnorm
|
|
71
|
+
#
|
|
72
|
+
# ===========================================================================
|
|
48
73
|
|
|
49
74
|
class TNorm:
|
|
50
75
|
'''Base class for fuzzy differentiable t-norms.'''
|
|
@@ -86,8 +111,133 @@ class LukasiewiczTNorm(TNorm):
|
|
|
86
111
|
|
|
87
112
|
def norms(self, x, axis):
|
|
88
113
|
return jax.nn.relu(jnp.sum(x - 1.0, axis=axis) + 1.0)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class YagerTNorm(TNorm):
|
|
117
|
+
'''Yager t-norm given by the expression
|
|
118
|
+
(x, y) -> max(0, 1 - ((1 - x)^p + (1 - y)^p)^(1/p)).'''
|
|
119
|
+
|
|
120
|
+
def __init__(self, p=2.0):
|
|
121
|
+
self.p = p
|
|
122
|
+
|
|
123
|
+
def norm(self, x, y):
|
|
124
|
+
base_x = jax.nn.relu(1.0 - x)
|
|
125
|
+
base_y = jax.nn.relu(1.0 - y)
|
|
126
|
+
arg = jnp.power(base_x ** self.p + base_y ** self.p, 1.0 / self.p)
|
|
127
|
+
return jax.nn.relu(1.0 - arg)
|
|
128
|
+
|
|
129
|
+
def norms(self, x, axis):
|
|
130
|
+
base = jax.nn.relu(1.0 - x)
|
|
131
|
+
arg = jnp.power(jnp.sum(base ** self.p, axis=axis), 1.0 / self.p)
|
|
132
|
+
return jax.nn.relu(1.0 - arg)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
# ===========================================================================
|
|
136
|
+
# RANDOM SAMPLING
|
|
137
|
+
# - abstract sampler
|
|
138
|
+
# - Gumbel-softmax sampler
|
|
139
|
+
# - determinization
|
|
140
|
+
#
|
|
141
|
+
# ===========================================================================
|
|
142
|
+
|
|
143
|
+
class RandomSampling:
|
|
144
|
+
'''An abstract class that describes how discrete and non-reparameterizable
|
|
145
|
+
random variables are sampled.'''
|
|
146
|
+
|
|
147
|
+
def discrete(self, logic):
|
|
148
|
+
raise NotImplementedError
|
|
149
|
+
|
|
150
|
+
def bernoulli(self, logic):
|
|
151
|
+
jax_discrete, jax_param = self.discrete(logic)
|
|
152
|
+
|
|
153
|
+
def _jax_wrapped_calc_bernoulli_approx(key, prob, param):
|
|
154
|
+
prob = jnp.stack([1.0 - prob, prob], axis=-1)
|
|
155
|
+
sample = jax_discrete(key, prob, param)
|
|
156
|
+
return sample
|
|
157
|
+
|
|
158
|
+
return _jax_wrapped_calc_bernoulli_approx, jax_param
|
|
159
|
+
|
|
160
|
+
def poisson(self, logic):
|
|
161
|
+
|
|
162
|
+
def _jax_wrapped_calc_poisson_exact(key, rate, param):
|
|
163
|
+
return random.poisson(key=key, lam=rate, dtype=logic.INT)
|
|
164
|
+
|
|
165
|
+
return _jax_wrapped_calc_poisson_exact, None
|
|
166
|
+
|
|
167
|
+
def geometric(self, logic):
|
|
168
|
+
if logic.verbose:
|
|
169
|
+
raise_warning('Using the replacement rule: '
|
|
170
|
+
'Geometric(p) --> floor(log(U) / log(1 - p)) + 1')
|
|
171
|
+
|
|
172
|
+
jax_floor, jax_param = logic.floor()
|
|
173
|
+
|
|
174
|
+
def _jax_wrapped_calc_geometric_approx(key, prob, param):
|
|
175
|
+
U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
|
|
176
|
+
sample = jax_floor(jnp.log(U) / jnp.log(1.0 - prob), param) + 1
|
|
177
|
+
return sample
|
|
178
|
+
|
|
179
|
+
return _jax_wrapped_calc_geometric_approx, jax_param
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
class GumbelSoftmax(RandomSampling):
|
|
183
|
+
'''Random sampling of discrete variables using Gumbel-softmax trick.'''
|
|
89
184
|
|
|
185
|
+
def discrete(self, logic):
|
|
186
|
+
if logic.verbose:
|
|
187
|
+
raise_warning('Using the replacement rule: '
|
|
188
|
+
'Discrete(p) --> Gumbel-softmax(p)')
|
|
189
|
+
|
|
190
|
+
jax_argmax, jax_param = logic.argmax()
|
|
191
|
+
|
|
192
|
+
def _jax_wrapped_calc_discrete_gumbel_softmax(key, prob, param):
|
|
193
|
+
Gumbel01 = random.gumbel(key=key, shape=prob.shape, dtype=logic.REAL)
|
|
194
|
+
sample = Gumbel01 + jnp.log(prob + logic.eps)
|
|
195
|
+
sample = jax_argmax(sample, axis=-1, param=param)
|
|
196
|
+
return sample
|
|
197
|
+
|
|
198
|
+
return _jax_wrapped_calc_discrete_gumbel_softmax, jax_param
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
class Determinization(RandomSampling):
|
|
202
|
+
'''Random sampling of variables using their deterministic mean estimate.'''
|
|
90
203
|
|
|
204
|
+
def discrete(self, logic):
|
|
205
|
+
if logic.verbose:
|
|
206
|
+
raise_warning('Using the replacement rule: '
|
|
207
|
+
'Discrete(p) --> sum(i * p[i])')
|
|
208
|
+
|
|
209
|
+
def _jax_wrapped_calc_discrete_determinized(key, prob, param):
|
|
210
|
+
literals = FuzzyLogic.enumerate_literals(prob.shape, axis=-1)
|
|
211
|
+
sample = jnp.sum(literals * prob, axis=-1)
|
|
212
|
+
return sample
|
|
213
|
+
|
|
214
|
+
return _jax_wrapped_calc_discrete_determinized, None
|
|
215
|
+
|
|
216
|
+
def poisson(self, logic):
|
|
217
|
+
if logic.verbose:
|
|
218
|
+
raise_warning('Using the replacement rule: Poisson(rate) --> rate')
|
|
219
|
+
|
|
220
|
+
def _jax_wrapped_calc_poisson_determinized(key, rate, param):
|
|
221
|
+
return rate
|
|
222
|
+
|
|
223
|
+
return _jax_wrapped_calc_poisson_determinized, None
|
|
224
|
+
|
|
225
|
+
def geometric(self, logic):
|
|
226
|
+
if logic.verbose:
|
|
227
|
+
raise_warning('Using the replacement rule: Geometric(p) --> 1 / p')
|
|
228
|
+
|
|
229
|
+
def _jax_wrapped_calc_geometric_determinized(key, prob, param):
|
|
230
|
+
sample = 1.0 / prob
|
|
231
|
+
return sample
|
|
232
|
+
|
|
233
|
+
return _jax_wrapped_calc_geometric_determinized, None
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
# ===========================================================================
|
|
237
|
+
# FUZZY LOGIC
|
|
238
|
+
#
|
|
239
|
+
# ===========================================================================
|
|
240
|
+
|
|
91
241
|
class FuzzyLogic:
|
|
92
242
|
'''A class representing fuzzy logic in JAX.
|
|
93
243
|
|
|
@@ -98,9 +248,10 @@ class FuzzyLogic:
|
|
|
98
248
|
def __init__(self, tnorm: TNorm=ProductTNorm(),
|
|
99
249
|
complement: Complement=StandardComplement(),
|
|
100
250
|
comparison: Comparison=SigmoidComparison(),
|
|
251
|
+
sampling: RandomSampling=GumbelSoftmax(),
|
|
101
252
|
weight: float=10.0,
|
|
102
253
|
debias: Optional[Set[str]]=None,
|
|
103
|
-
eps: float=1e-
|
|
254
|
+
eps: float=1e-15,
|
|
104
255
|
verbose: bool=False,
|
|
105
256
|
use64bit: bool=False) -> None:
|
|
106
257
|
'''Creates a new fuzzy logic in Jax.
|
|
@@ -108,8 +259,8 @@ class FuzzyLogic:
|
|
|
108
259
|
:param tnorm: fuzzy operator for logical AND
|
|
109
260
|
:param complement: fuzzy operator for logical NOT
|
|
110
261
|
:param comparison: fuzzy operator for comparisons (>, >=, <, ==, ~=, ...)
|
|
262
|
+
:param sampling: random sampling of non-reparameterizable distributions
|
|
111
263
|
:param weight: a sharpness parameter for sigmoid and softmax activations
|
|
112
|
-
:param error: an error parameter (e.g. floor) (smaller means better accuracy)
|
|
113
264
|
:param debias: which functions to de-bias approximate on forward pass
|
|
114
265
|
:param eps: small positive float to mitigate underflow
|
|
115
266
|
:param verbose: whether to dump replacements and other info to console
|
|
@@ -118,6 +269,7 @@ class FuzzyLogic:
|
|
|
118
269
|
self.tnorm = tnorm
|
|
119
270
|
self.complement = complement
|
|
120
271
|
self.comparison = comparison
|
|
272
|
+
self.sampling = sampling
|
|
121
273
|
self.weight = float(weight)
|
|
122
274
|
if debias is None:
|
|
123
275
|
debias = set()
|
|
@@ -142,10 +294,11 @@ class FuzzyLogic:
|
|
|
142
294
|
f' tnorm ={type(self.tnorm).__name__}\n'
|
|
143
295
|
f' complement ={type(self.complement).__name__}\n'
|
|
144
296
|
f' comparison ={type(self.comparison).__name__}\n'
|
|
297
|
+
f' sampling ={type(self.sampling).__name__}\n'
|
|
145
298
|
f' sigmoid_weight={self.weight}\n'
|
|
146
299
|
f' cpfs_to_debias={self.debias}\n'
|
|
147
300
|
f' underflow_tol ={self.eps}\n'
|
|
148
|
-
f'
|
|
301
|
+
f' use_64_bit ={self.use64bit}')
|
|
149
302
|
|
|
150
303
|
# ===========================================================================
|
|
151
304
|
# logical operators
|
|
@@ -419,7 +572,7 @@ class FuzzyLogic:
|
|
|
419
572
|
# ===========================================================================
|
|
420
573
|
|
|
421
574
|
@staticmethod
|
|
422
|
-
def
|
|
575
|
+
def enumerate_literals(shape, axis):
|
|
423
576
|
literals = jnp.arange(shape[axis])
|
|
424
577
|
literals = literals[(...,) + (jnp.newaxis,) * (len(shape) - 1)]
|
|
425
578
|
literals = jnp.moveaxis(literals, source=0, destination=axis)
|
|
@@ -434,7 +587,7 @@ class FuzzyLogic:
|
|
|
434
587
|
debias = 'argmax' in self.debias
|
|
435
588
|
|
|
436
589
|
def _jax_wrapped_calc_argmax_approx(x, axis, param):
|
|
437
|
-
literals = FuzzyLogic.
|
|
590
|
+
literals = FuzzyLogic.enumerate_literals(x.shape, axis=axis)
|
|
438
591
|
soft_max = jax.nn.softmax(param * x, axis=axis)
|
|
439
592
|
sample = jnp.sum(literals * soft_max, axis=axis)
|
|
440
593
|
if debias:
|
|
@@ -468,7 +621,7 @@ class FuzzyLogic:
|
|
|
468
621
|
def _jax_wrapped_calc_if_approx(c, a, b, param):
|
|
469
622
|
sample = c * a + (1.0 - c) * b
|
|
470
623
|
if debias:
|
|
471
|
-
hard_sample = jnp.
|
|
624
|
+
hard_sample = jnp.where(c > 0.5, a, b)
|
|
472
625
|
sample += jax.lax.stop_gradient(hard_sample - sample)
|
|
473
626
|
return sample
|
|
474
627
|
|
|
@@ -483,7 +636,7 @@ class FuzzyLogic:
|
|
|
483
636
|
debias = 'switch' in self.debias
|
|
484
637
|
|
|
485
638
|
def _jax_wrapped_calc_switch_approx(pred, cases, param):
|
|
486
|
-
literals = FuzzyLogic.
|
|
639
|
+
literals = FuzzyLogic.enumerate_literals(cases.shape, axis=0)
|
|
487
640
|
pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=cases.shape)
|
|
488
641
|
proximity = -jnp.abs(pred - literals)
|
|
489
642
|
soft_case = jax.nn.softmax(param * proximity, axis=0)
|
|
@@ -502,44 +655,24 @@ class FuzzyLogic:
|
|
|
502
655
|
# random variables
|
|
503
656
|
# ===========================================================================
|
|
504
657
|
|
|
505
|
-
def
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
return sample
|
|
509
|
-
|
|
658
|
+
def discrete(self):
|
|
659
|
+
return self.sampling.discrete(self)
|
|
660
|
+
|
|
510
661
|
def bernoulli(self):
|
|
511
|
-
|
|
512
|
-
raise_warning('Using the replacement rule: '
|
|
513
|
-
'Bernoulli(p) --> Gumbel-softmax(p)')
|
|
514
|
-
|
|
515
|
-
jax_gs = self._gumbel_softmax
|
|
516
|
-
jax_argmax, jax_param = self.argmax()
|
|
517
|
-
|
|
518
|
-
def _jax_wrapped_calc_bernoulli_approx(key, prob, param):
|
|
519
|
-
prob = jnp.stack([1.0 - prob, prob], axis=-1)
|
|
520
|
-
sample = jax_gs(key, prob)
|
|
521
|
-
sample = jax_argmax(sample, axis=-1, param=param)
|
|
522
|
-
return sample
|
|
523
|
-
|
|
524
|
-
return _jax_wrapped_calc_bernoulli_approx, jax_param
|
|
662
|
+
return self.sampling.bernoulli(self)
|
|
525
663
|
|
|
526
|
-
def
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
jax_argmax, jax_param = self.argmax()
|
|
533
|
-
|
|
534
|
-
def _jax_wrapped_calc_discrete_approx(key, prob, param):
|
|
535
|
-
sample = jax_gs(key, prob)
|
|
536
|
-
sample = jax_argmax(sample, axis=-1, param=param)
|
|
537
|
-
return sample
|
|
538
|
-
|
|
539
|
-
return _jax_wrapped_calc_discrete_approx, jax_param
|
|
540
|
-
|
|
664
|
+
def poisson(self):
|
|
665
|
+
return self.sampling.poisson(self)
|
|
666
|
+
|
|
667
|
+
def geometric(self):
|
|
668
|
+
return self.sampling.geometric(self)
|
|
669
|
+
|
|
541
670
|
|
|
671
|
+
# ===========================================================================
|
|
542
672
|
# UNIT TESTS
|
|
673
|
+
#
|
|
674
|
+
# ===========================================================================
|
|
675
|
+
|
|
543
676
|
logic = FuzzyLogic()
|
|
544
677
|
w = 100.0
|
|
545
678
|
|
|
@@ -598,13 +731,14 @@ def _test_random():
|
|
|
598
731
|
key = random.PRNGKey(42)
|
|
599
732
|
_bernoulli, _ = logic.bernoulli()
|
|
600
733
|
_discrete, _ = logic.discrete()
|
|
734
|
+
_geometric, _ = logic.geometric()
|
|
601
735
|
|
|
602
736
|
def bern(n):
|
|
603
737
|
prob = jnp.asarray([0.3] * n)
|
|
604
738
|
sample = _bernoulli(key, prob, w)
|
|
605
739
|
return sample
|
|
606
740
|
|
|
607
|
-
samples = bern(
|
|
741
|
+
samples = bern(50000)
|
|
608
742
|
print(jnp.mean(samples))
|
|
609
743
|
|
|
610
744
|
def disc(n):
|
|
@@ -613,10 +747,18 @@ def _test_random():
|
|
|
613
747
|
sample = _discrete(key, prob, w)
|
|
614
748
|
return sample
|
|
615
749
|
|
|
616
|
-
samples = disc(
|
|
750
|
+
samples = disc(50000)
|
|
617
751
|
samples = jnp.round(samples)
|
|
618
752
|
print([jnp.mean(samples == i) for i in range(3)])
|
|
619
|
-
|
|
753
|
+
|
|
754
|
+
def geom(n):
|
|
755
|
+
prob = jnp.asarray([0.3] * n)
|
|
756
|
+
sample = _geometric(key, prob, w)
|
|
757
|
+
return sample
|
|
758
|
+
|
|
759
|
+
samples = geom(50000)
|
|
760
|
+
print(jnp.mean(samples))
|
|
761
|
+
|
|
620
762
|
|
|
621
763
|
def _test_rounding():
|
|
622
764
|
print('testing rounding')
|